{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp basics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from fastcore.imports import *\n",
    "import builtins,types,typing\n",
    "import pprint\n",
    "from copy import copy\n",
    "try: from types import UnionType\n",
    "except ImportError: UnionType = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "from __future__ import annotations\n",
    "from fastcore.test import *\n",
    "from nbdev.showdoc import *\n",
    "from fastcore.nb_imports import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic functionality\n",
    "\n",
    "> Basic functionality used in the fastai library"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Module-level, initially empty attribute namespace; presumably other code attaches shared default settings to it (TODO confirm)\n",
    "defaults = SimpleNamespace()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def ifnone(a, b):\n",
    "    \"`b` if `a` is None else `a`\"\n",
    "    return b if a is None else a"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `b if a is None else a` is such a common pattern, we wrap it in a function. However, be careful, because Python will evaluate *both* `a` and `b` when calling `ifnone` (which it doesn't do if using the `if` version directly)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(ifnone(None,1), 1)\n",
    "test_eq(ifnone(2   ,1), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def maybe_attr(o, attr):\n",
    "    \"`getattr(o,attr,o)`\"\n",
    "    return getattr(o,attr,o)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Return the attribute `attr` for object `o`.  If the attribute doesn't exist, then return the object `o` instead. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myobj: myattr='foo'\n",
    "\n",
    "test_eq(maybe_attr(myobj, 'myattr'), 'foo')\n",
    "test_eq(maybe_attr(myobj, 'another_attr'), myobj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def basic_repr(flds=None):\n",
    "    \"Minimal `__repr__`\"\n",
    "    if isinstance(flds, str): flds = re.split(', *', flds)\n",
    "    flds = list(flds or [])\n",
    "    def _f(self):\n",
    "        res = f'{type(self).__module__}.{type(self).__name__}'\n",
    "        if not flds: return f'<{res}>'\n",
    "        sig = ', '.join(f'{o}={getattr(self,o)!r}' for o in flds)\n",
    "        return f'{res}({sig})'\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In types which provide rich display functionality in Jupyter, their `__repr__` is also called in order to provide a fallback text representation. Unfortunately, this includes a memory address which changes on every invocation, making it non-deterministic. This causes diffs to get messy and creates conflicts in git. To fix this, put `__repr__=basic_repr()` inside your class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<__main__.SomeClass>'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class SomeClass: __repr__=basic_repr()\n",
    "repr(SomeClass())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you pass a list of attributes (`flds`) of an object (or a single comma-separated string of attribute names), then this will generate a string with the name of each attribute and its corresponding value. The format of this string is `key=value`, where `key` is the name of the attribute, and `value` is the `repr` of the attribute's value. A value's `__name__` is not used: in the `AnotherClass` example below, `c` is shown via its `repr` even though it has a `__name__` of `'some-class'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"__main__.SomeClass(a=1, b='foo')\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class SomeClass:\n",
    "    a=1\n",
    "    b='foo'\n",
    "    __repr__=basic_repr('a,b')\n",
    "    __name__='some-class'\n",
    "\n",
    "repr(SomeClass())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"__main__.AnotherClass(c=__main__.SomeClass(a=1, b='foo'), d='bar')\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class AnotherClass:\n",
    "    c=SomeClass()\n",
    "    d='bar'\n",
    "    __repr__=basic_repr(['c', 'd'])\n",
    "\n",
    "repr(AnotherClass())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def is_array(x):\n",
    "    \"`True` if `x` supports `__array__` or `iloc`\"\n",
    "    return hasattr(x,'__array__') or hasattr(x,'iloc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(True, False)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "is_array(np.array(1)),is_array([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def listify(o=None, *rest, use_list=False, match=None):\n",
    "    \"Convert `o` to a `list`\"\n",
    "    # Extra positional args are gathered together with `o` into one sequence.\n",
    "    # Strings, bytes and arrays are iterable but are treated as single items.\n",
    "    if rest: o = (o,)+rest\n",
    "    if use_list: items = list(o)\n",
    "    elif o is None: items = []\n",
    "    elif isinstance(o, list): items = o\n",
    "    elif isinstance(o, (str, bytes)) or is_array(o): items = [o]\n",
    "    elif is_iter(o): items = list(o)\n",
    "    else: items = [o]\n",
    "    if match is None: return items\n",
    "    # `match` is either a target length or a collection whose length is used\n",
    "    n = len(match) if is_coll(match) else match\n",
    "    # A single item is broadcast to the target length; otherwise the lengths must agree\n",
    "    if len(items)==1: return items*n\n",
    "    assert len(items)==n, 'Match length mismatch'\n",
    "    return items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conversion is designed to \"do what you mean\", e.g:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(listify('hi'), ['hi'])\n",
    "test_eq(listify(b'hi'), [b'hi'])\n",
    "test_eq(listify(array(1)), [array(1)])\n",
    "test_eq(listify(1), [1])\n",
    "test_eq(listify([1,2]), [1,2])\n",
    "test_eq(listify(range(3)), [0,1,2])\n",
    "test_eq(listify(None), [])\n",
    "test_eq(listify(1,2), [1,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[0, 1, 2],\n",
       "        [3, 4, 5],\n",
       "        [6, 7, 8]])]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr = np.arange(9).reshape(3,3)\n",
    "listify(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([1, 2])]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "listify(array([1,2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generators are turned into lists too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = (o for o in range(3))\n",
    "test_eq(listify(gen), [0,1,2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use `match` to provide a length to match:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(listify(1,match=3), [1,1,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `match` is a collection, its length is used:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(listify(1,match=range(3)), [1,1,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the listified item is not of length `1`, it must be the same length as `match`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(listify([1,1,1],match=3), [1,1,1])\n",
    "test_fail(lambda: listify([1,1],match=3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def tuplify(o, use_list=False, match=None):\n",
    "    \"Make `o` a tuple\"\n",
    "    # Reuse `listify`'s conversion rules, then freeze the result\n",
    "    as_list = listify(o, use_list=use_list, match=match)\n",
    "    return tuple(as_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(tuplify(None),())\n",
    "test_eq(tuplify([1,2,3]),(1,2,3))\n",
    "test_eq(tuplify(1,match=[1,2,3]),(1,1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def true(x):\n",
    "    \"Test whether `x` is truthy; collections with >0 elements are considered `True`\"\n",
    "    try: return bool(len(x))\n",
    "    except: return bool(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(array(0), False),\n",
       " (array(1), True),\n",
       " (array([0]), True),\n",
       " (array([0, 1]), True),\n",
       " (1, True),\n",
       " (0, False),\n",
       " ('', False),\n",
       " (None, False)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(o,true(o)) for o in\n",
    " (array(0),array(1),array([0]),array([0,1]),1,0,'',None)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class NullType:\n",
    "    \"An object that is `False` and can be called, chained, and indexed\"\n",
    "    def __getattr__(self,*args):return null\n",
    "    def __call__(self,*args, **kwargs):return null\n",
    "    def __getitem__(self, *args):return null\n",
    "    def __bool__(self): return False\n",
    "\n",
    "null = NullType()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bool(null.hi().there[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def tonull(x):\n",
    "    \"Convert `None` to `null`\"\n",
    "    # Only `None` is swapped; every other value (including falsy ones) passes through unchanged\n",
    "    if x is None: return null\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bool(tonull(None).hi().there[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_class(nm, *fld_names, sup=None, doc=None, funcs=None, anno=None, **flds):\n",
    "    \"Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`\"\n",
    "    attrs = {}\n",
    "    # Work on a copy so a caller-supplied `anno` dict is never mutated\n",
    "    anno = dict(anno or {})\n",
    "    for f in fld_names:\n",
    "        attrs[f] = None\n",
    "        if f not in anno: anno[f] = typing.Any\n",
    "    for f in listify(funcs): attrs[f.__name__] = f\n",
    "    for k,v in flds.items(): attrs[k] = v\n",
    "    sup = ifnone(sup, ())\n",
    "    if not isinstance(sup, tuple): sup=(sup,)\n",
    "\n",
    "    # Positional args fill only the declared fields, in order (`fld_names`, then `flds` keys) --\n",
    "    # never the names contributed by `funcs` or the dunder entries added to `attrs` below\n",
    "    fields = [*fld_names,*flds.keys()]\n",
    "    def _init(self, *args, **kwargs):\n",
    "        if len(args) > len(fields):\n",
    "            raise TypeError(f'{nm}() takes at most {len(fields)} positional arguments ({len(args)} given)')\n",
    "        for k,v in zip(fields, args): setattr(self, k, v)\n",
    "        for k,v in kwargs.items(): setattr(self,k,v)\n",
    "\n",
    "    attrs['_fields'] = list(fields)\n",
    "    def _eq(self,b):\n",
    "        return all([getattr(self,k)==getattr(b,k) for k in self._fields])\n",
    "\n",
    "    # Only classes without a base get the generated `__repr__`; otherwise the base's is kept\n",
    "    if not sup: attrs['__repr__'] = basic_repr(attrs['_fields'])\n",
    "    attrs['__init__'] = _init\n",
    "    attrs['__eq__'] = _eq\n",
    "    if anno: attrs['__annotations__'] = anno\n",
    "    res = type(nm, sup, attrs)\n",
    "    if doc is not None: res.__doc__ = doc\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L105){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### get_class\n",
       "\n",
       ">      get_class (nm, *fld_names, sup=None, doc=None, funcs=None, anno=None,\n",
       ">                 **flds)\n",
       "\n",
       "*Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L105){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### get_class\n",
       "\n",
       ">      get_class (nm, *fld_names, sup=None, doc=None, funcs=None, anno=None,\n",
       ">                 **flds)\n",
       "\n",
       "*Dynamically create a class, optionally inheriting from `sup`, containing `fld_names`*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(get_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'__main__._t(a=1, b=3)'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_t = get_class('_t', 'a', b=2, anno={'b':int})\n",
    "t = _t()\n",
    "test_eq(t.a, None)\n",
    "test_eq(t.b, 2)\n",
    "t = _t(1, b=3)\n",
    "test_eq(t.a, 1)\n",
    "test_eq(t.b, 3)\n",
    "t = _t(1, 3)\n",
    "test_eq(t.a, 1)\n",
    "test_eq(t.b, 3)\n",
    "test_eq(t, pickle.loads(pickle.dumps(t)))\n",
    "test_eq(_t.__annotations__, {'b':int, 'a':typing.Any})\n",
    "repr(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most often you'll want to call `mk_class`, since it adds the class to your module. See `mk_class` for more details and examples of use (which also apply to `get_class`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def mk_class(nm, *fld_names, sup=None, doc=None, funcs=None, mod=None, anno=None, **flds):\n",
    "    \"Create a class using `get_class` and add to the caller's module\"\n",
    "    # With no `mod`, target the namespace of the frame that called us (`sys._getframe(1)`);\n",
    "    # at module level that frame's `f_locals` is the module's globals dict\n",
    "    target = sys._getframe(1).f_locals if mod is None else mod\n",
    "    target[nm] = get_class(nm, *fld_names, sup=sup, doc=doc, funcs=funcs, anno=anno, **flds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Any `kwargs` will be added as class attributes, and `sup` is an optional (tuple of) base classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mk_class('_t', a=1, sup=dict)\n",
    "t = _t()\n",
    "test_eq(t.a, 1)\n",
    "assert(isinstance(t,dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An `__init__` is provided that sets attrs for any `kwargs`, and for any `args` (matching by position to fields), along with a `__repr__` which prints all fields (only added when no `sup` is given, so the `dict` subclass below keeps `dict`'s own repr). The docstring is set to `doc`. You can pass `funcs` which will be added as attrs with the function names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def foo(self): return 1\n",
    "mk_class('_t', 'a', sup=dict, doc='test doc', funcs=foo)\n",
    "\n",
    "t = _t(3, b=2)\n",
    "test_eq(t.a, 3)\n",
    "test_eq(t.b, 2)\n",
    "test_eq(t.foo(), 1)\n",
    "test_eq(t.__doc__, 'test doc')\n",
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def wrap_class(nm, *fld_names, sup=None, doc=None, funcs=None, **flds):\n",
    "    \"Decorator: makes function a method of a new class `nm` passing parameters to `mk_class`\"\n",
    "    def _decorate(f):\n",
    "        # `f` joins any extra `funcs` as a method; the class is created in `f`'s own module globals\n",
    "        methods = listify(funcs) + [f]\n",
    "        mk_class(nm, *fld_names, sup=sup, doc=doc, funcs=methods, mod=f.__globals__, **flds)\n",
    "        # Hand back `f` unchanged, so the plain function stays usable too\n",
    "        return f\n",
    "    return _decorate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@wrap_class('_t', a=2)\n",
    "def bar(self,x): return x+1\n",
    "\n",
    "t = _t()\n",
    "test_eq(t.a, 2)\n",
    "test_eq(t.bar(3), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ignore_exceptions:\n",
    "    \"Context manager to ignore exceptions\"\n",
    "    def __enter__(self): pass\n",
    "    def __exit__(self, *args): return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### ignore_exceptions\n",
       "\n",
       ">      ignore_exceptions ()\n",
       "\n",
       "*Context manager to ignore exceptions*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### ignore_exceptions\n",
       "\n",
       ">      ignore_exceptions ()\n",
       "\n",
       "*Context manager to ignore exceptions*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ignore_exceptions, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with ignore_exceptions(): \n",
    "    # Exception will be ignored\n",
    "    raise Exception"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def exec_local(code, var_name):\n",
    "    \"Call `exec` on `code` and return the var `var_name`\"\n",
    "    loc = {}\n",
    "    exec(code, globals(), loc)\n",
    "    return loc[var_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(exec_local(\"a=1\", \"a\"), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _risinstance(types, obj):\n",
    "    if any(isinstance(t,str) for t in types):\n",
    "        return any(t.__name__ in types for t in type(obj).__mro__)\n",
    "    return isinstance(obj, types)\n",
    "\n",
    "def risinstance(types, obj=None):\n",
    "    \"Curried `isinstance` but with args reversed\"\n",
    "    types = tuplify(types)\n",
    "    if obj is None: return partial(_risinstance,types)\n",
    "    return _risinstance(types, obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert risinstance(int, 1)\n",
    "assert not risinstance(str, 0)\n",
    "assert risinstance(int)(1)\n",
    "assert not risinstance(int)(None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`types` can also be strings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert risinstance(('str','int'), 'a')\n",
    "assert risinstance('str', 'a')\n",
    "assert not risinstance('int', 'a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def ver2tuple(v:str)->tuple:\n",
    "    return tuple(int(o or 0) for o in re.search(r'(\\d+)(?:\\.(\\d+))?(?:\\.(\\d+))?', v).groups())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(ver2tuple('3.8.1'), (3,8,1))\n",
    "test_eq(ver2tuple('3.1'), (3,1,0))\n",
    "test_eq(ver2tuple('3.'), (3,0,0))\n",
    "test_eq(ver2tuple('3'), (3,0,0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NoOp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are used when you need a pass-through function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### noop\n",
       "\n",
       ">      noop (x=None, *args, **kwargs)\n",
       "\n",
       "*Do nothing*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### noop\n",
       "\n",
       ">      noop (x=None, *args, **kwargs)\n",
       "\n",
       "*Do nothing*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(noop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noop()\n",
    "test_eq(noop(1),1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### noops\n",
       "\n",
       ">      noops (x=None, *args, **kwargs)\n",
       "\n",
       "*Do nothing (method)*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### noops\n",
       "\n",
       ">      noops (x=None, *args, **kwargs)\n",
       "\n",
       "*Do nothing (method)*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(noops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _t: foo=noops\n",
    "test_eq(_t().foo(1),1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Infinite Lists"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These lists are useful for things like padding an array or adding index column(s) to arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "#|hide\n",
    "class _InfMeta(type):\n",
    "    @property\n",
    "    def count(self): return itertools.count()\n",
    "    @property\n",
    "    def zeros(self): return itertools.cycle([0])\n",
    "    @property\n",
    "    def ones(self):  return itertools.cycle([1])\n",
    "    @property\n",
    "    def nones(self): return itertools.cycle([None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Inf(metaclass=_InfMeta):\n",
    "    \"Infinite lists\"\n",
    "    # Intentionally empty: `count`, `zeros`, `ones` and `nones` come from the metaclass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(Inf);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Inf` defines the following properties:\n",
    "    \n",
    "- `count: itertools.count()`\n",
    "- `zeros: itertools.cycle([0])`\n",
    "- `ones : itertools.cycle([1])`\n",
    "- `nones: itertools.cycle([None])`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq([o for i,o in zip(range(5), Inf.count)],\n",
    "        [0, 1, 2, 3, 4])\n",
    "\n",
    "test_eq([o for i,o in zip(range(5), Inf.zeros)],\n",
    "        [0]*5)\n",
    "\n",
    "test_eq([o for i,o in zip(range(5), Inf.ones)],\n",
    "        [1]*5)\n",
    "\n",
    "test_eq([o for i,o in zip(range(5), Inf.nones)],\n",
    "        [None]*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Operator Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "_dumobj = object()\n",
    "def _oper(op,a,b=_dumobj): return (lambda o:op(o,a)) if b is _dumobj else op(a,b)\n",
    "\n",
    "def _mk_op(nm, mod):\n",
    "    \"Create an operator using `oper` and add to the caller's module\"\n",
    "    op = getattr(operator,nm)\n",
    "    def _inner(a, b=_dumobj): return _oper(op, a,b)\n",
    "    _inner.__name__ = _inner.__qualname__ = nm\n",
    "    _inner.__doc__ = f'Same as `operator.{nm}`, or returns partial if 1 arg'\n",
    "    mod[nm] = _inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def in_(x, a):\n",
    "    \"`True` if `x in a`\"\n",
    "    return x in a\n",
    "\n",
    "operator.in_ = in_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Names of `operator` functions turned into module-level helpers by the `_mk_op` loop below;\n",
    "# nbdev presumably also appends the names in `_all_` to the generated `__all__` (confirm)\n",
    "_all_ = ['lt','gt','le','ge','eq','ne','add','sub','mul','truediv','is_','is_not','in_', 'mod']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Bind one wrapper per name into this module's globals (this replaces the plain `in_` defined earlier\n",
    "# with its curried wrapper); the loop variable `op` is also left behind as a module-level name\n",
    "for op in _all_: _mk_op(op, globals())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test if element is in another\n",
    "assert in_('c', ('b', 'c', 'a'))\n",
    "assert in_(4, [2,3,4,5])\n",
    "assert in_('t', 'fastai')\n",
    "# `test_fail` expects a callable; handing it a bool (calling which raises TypeError) would pass vacuously\n",
    "assert not in_('h', 'fastai')\n",
    "\n",
    "# use in_ as a partial\n",
    "assert in_('fastai')('t')\n",
    "assert in_([2,3,4,5])(4)\n",
    "assert not in_('fastai')('h')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to `in_`, the following functions are provided matching the behavior of the equivalent versions in `operator`: *lt gt le ge eq ne add sub mul truediv is_ is_not mod*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(True, False, True, False, 1)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lt(3,5),gt(3,5),is_(None,None),in_(0,[1,2]),mod(3,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly to `in_`, they also have additional functionality: if you only pass one param, they return a partial function that passes that param as the second positional parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(True, False, True, False, 1)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lt(5)(3),gt(5)(3),is_(None)(None),in_([1,2])(0),mod(2)(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def ret_true(*args, **kwargs):\n",
    "    \"Predicate: always `True`\"\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert ret_true(1,2,3)\n",
    "assert ret_true(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def ret_false(*args, **kwargs):\n",
    "    \"Predicate: always `False`\"\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def stop(e=StopIteration):\n",
    "    \"Raises exception `e` (by default `StopIteration`)\"\n",
    "    raise e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def gen(func, seq, cond=ret_true):\n",
    "    \"Like `(func(o) for o in seq)`, but stops at the first result where `cond` is falsy, and handles `StopIteration`\"\n",
    "    # `takewhile` stops (it does not filter) at the first result failing `cond`; a `StopIteration`\n",
    "    # raised by `func` inside `map` simply ends the iteration instead of propagating\n",
    "    return itertools.takewhile(cond, map(func,seq))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(gen(noop, Inf.count, lt(5)),\n",
    "        range(5))\n",
    "test_eq(gen(operator.neg, Inf.count, gt(-5)),\n",
    "        [0,-1,-2,-3,-4])\n",
    "test_eq(gen(lambda o:o if o<5 else stop(), Inf.count),\n",
    "        range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):\n",
    "    \"Return batches from iterator `it` of size `chunk_sz` (or return `n_chunks` total)\"\n",
    "    assert bool(chunk_sz) ^ bool(n_chunks)\n",
    "    if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)\n",
    "    if not isinstance(it, Iterator): it = iter(it)\n",
    "    while True:\n",
    "        res = list(itertools.islice(it, chunk_sz))\n",
    "        if res and (len(res)==chunk_sz or not drop_last): yield res\n",
    "        if len(res)<chunk_sz: return"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that you must pass either `chunk_sz`, or `n_chunks`, but not both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Works with lists, iterators that stop early (via `stop`), numpy arrays, and empty input\n",
    "t = list(range(10))\n",
    "test_eq(chunked(t,3),      [[0,1,2], [3,4,5], [6,7,8], [9]])\n",
    "test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8],    ])\n",
    "\n",
    "t = map(lambda o:stop() if o==6 else o, Inf.count)\n",
    "test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5]])\n",
    "t = map(lambda o:stop() if o==7 else o, Inf.count)\n",
    "test_eq(chunked(t,3), [[0, 1, 2], [3, 4, 5], [6]])\n",
    "\n",
    "t = np.arange(10)\n",
    "test_eq(chunked(t,3),      [[0,1,2], [3,4,5], [6,7,8], [9]])\n",
    "test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8],    ])\n",
    "\n",
    "test_eq(chunked([], 3),          [])\n",
    "test_eq(chunked([], n_chunks=3), [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def otherwise(x, tst, y):\n",
    "    \"`y if tst(x) else x`\"\n",
    "    return y if tst(x) else x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `gt(n)` tests `x>n`: 3>3 is False (keep 3), while 3>2 is True (use 4)\n",
    "test_eq(otherwise(2+1, gt(3), 4), 3)\n",
    "test_eq(otherwise(2+1, gt(2), 4), 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute Helpers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These functions reduce boilerplate when setting or manipulating attributes or properties of objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def custom_dir(c, add):\n",
    "    \"Implement custom `__dir__`, adding `add` to `cls`\"\n",
    "    return object.__dir__(c) + listify(add)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
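    "`custom_dir` allows you to extract the default attribute listing of an object (via `object.__dir__`, which gathers names from the object's [`__dict__`](https://stackoverflow.com/questions/19907442/explain-dict-attribute) and those of its class and base classes) and appends the list `add` to it."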
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `custom_dir` keeps the normal attributes (like `f`) and appends the extra names\n",
    "class _T: \n",
    "    def f(): pass\n",
    "\n",
    "s = custom_dir(_T(), add=['foo', 'bar'])\n",
    "assert {'foo', 'bar', 'f'}.issubset(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AttrDict(dict):\n",
    "    \"`dict` subclass that also provides access to keys as attrs\"\n",
    "    def __getattr__(self,k): return self[k] if k in self else stop(AttributeError(k))\n",
    "    def __setattr__(self, k, v): (self.__setitem__,super().__setattr__)[k[0]=='_'](k,v)\n",
    "    def __dir__(self): return super().__dir__() + list(self.keys())\n",
    "    def _repr_markdown_(self): return f'```json\\n{pprint.pformat(self, indent=2)}\\n```'\n",
    "    def copy(self): return AttrDict(**self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribute and item access read and write the same underlying dict entries\n",
    "d = AttrDict(a=1,b=\"two\")\n",
    "test_eq(d.a, 1)\n",
    "test_eq(d['b'], 'two')\n",
    "test_eq(d.get('c','nope'), 'nope')\n",
    "d.b = 2\n",
    "test_eq(d.b, 2)\n",
    "test_eq(d['b'], 2)\n",
    "d['b'] = 3\n",
    "test_eq(d['b'], 3)\n",
    "test_eq(d.b, 3)\n",
    "assert 'a' in dir(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`AttrDict` will pretty print in Jupyter Notebooks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```json\n",
       "{ 'a': 1,\n",
       "  'b': {'c': 1, 'd': 2},\n",
       "  'c': {'c': 1, 'd': 2},\n",
       "  'd': {'c': 1, 'd': 2},\n",
       "  'e': {'c': 1, 'd': 2},\n",
       "  'f': {'c': 1, 'd': 2, 'e': 4, 'f': [1, 2, 3, 4, 5]}}\n",
       "```"
      ],
      "text/plain": [
       "{'a': 1,\n",
       " 'b': {'c': 1, 'd': 2},\n",
       " 'c': {'c': 1, 'd': 2},\n",
       " 'd': {'c': 1, 'd': 2},\n",
       " 'e': {'c': 1, 'd': 2},\n",
       " 'f': {'c': 1, 'd': 2, 'e': 4, 'f': [1, 2, 3, 4, 5]}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `_repr_markdown_` makes Jupyter render the dict as a pretty-printed code block\n",
    "_test_dict = {'a':1, 'b': {'c':1, 'd':2}, 'c': {'c':1, 'd':2}, 'd': {'c':1, 'd':2},\n",
    "              'e': {'c':1, 'd':2}, 'f': {'c':1, 'd':2, 'e': 4, 'f':[1,2,3,4,5]}}\n",
    "AttrDict(_test_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AttrDictDefault(AttrDict):\n",
    "    \"`AttrDict` subclass that returns `default_` (`None` unless given) for missing attrs\"\n",
    "    def __init__(self, *args, default_=None, **kwargs):\n",
    "        self.default_ = default_\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "    def __setattr__(self, k, v):\n",
    "        # Keep the default as a private attribute rather than a dict key, so it doesn't leak\n",
    "        # into `keys()`/`items()`/the repr; assigning `d.default_ = x` still updates it\n",
    "        super().__setattr__('_default' if k=='default_' else k, v)\n",
    "\n",
    "    def __getattr__(self,k):\n",
    "        if k in self: return self[k]\n",
    "        # Private/dunder lookups (e.g. `__deepcopy__`, or `_default` before it is set) must raise,\n",
    "        # otherwise protocols such as `copy.deepcopy` would receive the default value instead\n",
    "        if k.startswith('_'): raise AttributeError(k)\n",
    "        return self._default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Missing attributes fall back to `default_` instead of raising `AttributeError`\n",
    "d = AttrDictDefault(a=1,b=\"two\", default_='nope')\n",
    "test_eq(d.a, 1)\n",
    "test_eq(d['b'], 'two')\n",
    "test_eq(d.c, 'nope')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class NS(SimpleNamespace):\n",
    "    \"`SimpleNamespace` subclass that also adds `iter` and `dict` support\"\n",
    "    # Iteration and indexing all go through `vars(self)`, the namespace's attribute dict\n",
    "    def __iter__(self): return iter(vars(self))\n",
    "    def __getitem__(self,x): return vars(self)[x]\n",
    "    def __setitem__(self,x,y): vars(self)[x] = y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is very similar to `AttrDict`, but since it starts with `SimpleNamespace`, it has some differences in behavior. You can use it just like `SimpleNamespace`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "namespace(a=1,\n",
       "          b={'c': 1, 'd': 2},\n",
       "          c={'c': 1, 'd': 2},\n",
       "          d={'c': 1, 'd': 2},\n",
       "          e={'c': 1, 'd': 2},\n",
       "          f={'c': 1, 'd': 2, 'e': 4, 'f': [1, 2, 3, 4, 5]})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keyword arguments become attributes, exactly as with `SimpleNamespace`\n",
    "d = NS(**_test_dict)\n",
    "d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...but you can also index it to get/set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indexing by name reads the attribute of that name\n",
    "d['a']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...and iterate over it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['a', 'b', 'c', 'd', 'e', 'f']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Iterating yields the attribute names, in insertion order\n",
    "list(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def get_annotations_ex(obj, *, globals=None, locals=None):\n",
    "    \"Backport of py3.10 `get_annotations` that returns globals/locals\"\n",
    "    # Returns `(copy_of_annotations, globals, locals)`; unless passed explicitly, `globals`/`locals`\n",
    "    # are the namespaces discovered for `obj` below, for use when `eval`-ing string annotations\n",
    "    if isinstance(obj, type):\n",
    "        # Classes: read `__annotations__` from the class's own `__dict__` only, so annotations\n",
    "        # defined on base classes are not included\n",
    "        obj_dict = getattr(obj, '__dict__', None)\n",
    "        if obj_dict and hasattr(obj_dict, 'get'):\n",
    "            ann = obj_dict.get('__annotations__', None)\n",
    "            # NOTE(review): presumably guards against finding a descriptor instead of a dict here - confirm\n",
    "            if isinstance(ann, types.GetSetDescriptorType): ann = None\n",
    "        else: ann = None\n",
    "\n",
    "        # Globals come from the module the class was defined in, if it is still loaded\n",
    "        obj_globals = None\n",
    "        module_name = getattr(obj, '__module__', None)\n",
    "        if module_name:\n",
    "            module = sys.modules.get(module_name, None)\n",
    "            if module: obj_globals = getattr(module, '__dict__', None)\n",
    "        # The class namespace acts as locals, so names defined in the class body resolve\n",
    "        obj_locals = dict(vars(obj))\n",
    "        unwrap = obj\n",
    "    elif isinstance(obj, types.ModuleType):\n",
    "        # Modules: their own namespace is the globals, and there are no locals\n",
    "        ann = getattr(obj, '__annotations__', None)\n",
    "        obj_globals = getattr(obj, '__dict__')\n",
    "        obj_locals,unwrap = None,None\n",
    "    elif callable(obj):\n",
    "        # Functions and other callables: globals come from `__globals__` when present\n",
    "        ann = getattr(obj, '__annotations__', None)\n",
    "        obj_globals = getattr(obj, '__globals__', None)\n",
    "        obj_locals,unwrap = None,obj\n",
    "    else: raise TypeError(f\"{obj!r} is not a module, class, or callable.\")\n",
    "\n",
    "    # Missing annotations become `{}`; anything other than a dict is an error\n",
    "    if ann is None: ann = {}\n",
    "    if not isinstance(ann, dict): raise ValueError(f\"{obj!r}.__annotations__ is neither a dict nor None\")\n",
    "    if not ann: ann = {}\n",
    "\n",
    "    if unwrap is not None:\n",
    "        # Follow `__wrapped__` chains (e.g. from `functools.wraps`) and `functools.partial`\n",
    "        # objects down to the underlying object, and prefer its `__globals__` if it has any\n",
    "        while True:\n",
    "            if hasattr(unwrap, '__wrapped__'):\n",
    "                unwrap = unwrap.__wrapped__\n",
    "                continue\n",
    "            if isinstance(unwrap, functools.partial):\n",
    "                unwrap = unwrap.func\n",
    "                continue\n",
    "            break\n",
    "        if hasattr(unwrap, \"__globals__\"): obj_globals = unwrap.__globals__\n",
    "\n",
    "    # Explicitly passed namespaces take precedence over the discovered ones\n",
    "    if globals is None: globals = obj_globals\n",
    "    if locals is None: locals = obj_locals\n",
    "\n",
    "    # Return a copy so callers can't mutate the object's own `__annotations__`\n",
    "    return dict(ann), globals, locals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In Python 3.10 `inspect.get_annotations` was added. However previous versions of Python are unable to evaluate type annotations correctly if `from __future__ import annotations` is used. Furthermore, *all* annotations are evaluated, even if only some subset are needed. `get_annotations_ex` provides the same functionality as `inspect.get_annotations`, but works on earlier versions of Python, and returns the `globals` and `locals` needed to evaluate types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _split_union(t):\n",
    "    \"Split type string `t` on the `|` characters that are not nested inside brackets\"\n",
    "    parts,depth,start = [],0,0\n",
    "    for i,ch in enumerate(t):\n",
    "        if ch in '([{': depth += 1\n",
    "        elif ch in ')]}': depth -= 1\n",
    "        elif ch=='|' and depth==0:\n",
    "            parts.append(t[start:i].strip())\n",
    "            start = i+1\n",
    "    parts.append(t[start:].strip())\n",
    "    return parts\n",
    "\n",
    "def eval_type(t, glb, loc):\n",
    "    \"`eval` a type or collection of types, if needed, for annotations in py3.10+\"\n",
    "    if isinstance(t,str):\n",
    "        # Only a top-level `|` builds a `Union` here (so it also works before py3.10); nested ones,\n",
    "        # e.g. `list[int|str]`, are left to `eval` instead of being split mid-expression\n",
    "        parts = _split_union(t)\n",
    "        if len(parts)>1: return Union[eval_type(tuple(parts), glb, loc)]\n",
    "        return eval(t, glb, loc)\n",
    "    # Evaluate each member of a tuple/list, keeping the container type\n",
    "    if isinstance(t,(tuple,list)): return type(t)([eval_type(c, glb, loc) for c in t])\n",
    "    return t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `from __future__ import annotations` is used (as it is in this notebook), `a` is a `str`, so `eval_type` is needed to turn it back into a type:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "__main__._T2a"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With the `__future__` import used in this notebook, `ann['a']` is a string that must be evaluated\n",
    "class _T2a: pass\n",
    "def func(a: _T2a): pass\n",
    "ann,glb,loc = get_annotations_ex(func)\n",
    "\n",
    "eval_type(ann['a'], glb, loc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`|` is supported for defining `Union` types when using `eval_type`, even for python versions prior to 3.10 (which added native `X|Y` syntax for types):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "typing.Union[__main__._T2a, __main__._T2b]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The string `'_T2a|_T2b'` is split on `|` and evaluated into a `Union`\n",
    "class _T2b: pass\n",
    "def func(a: _T2a|_T2b): pass\n",
    "ann,glb,loc = get_annotations_ex(func)\n",
    "\n",
    "eval_type(ann['a'], glb, loc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _eval_type(t, glb, loc):\n",
    "    # A `None` annotation is reported as `NoneType`\n",
    "    evaluated = eval_type(t, glb, loc)\n",
    "    if evaluated is None: return NoneType\n",
    "    return evaluated\n",
    "\n",
    "def type_hints(f):\n",
    "    \"Like `typing.get_type_hints` but returns `{}` if not allowed type\"\n",
    "    # Objects outside `typing._allowed_types` (e.g. classes) get no hints at all\n",
    "    if isinstance(f, typing._allowed_types):\n",
    "        ann,glb,loc = get_annotations_ex(f)\n",
    "        return {name:_eval_type(a, glb, loc) for name,a in ann.items()}\n",
    "    return {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is the list of object types that `type_hints` accepts (taken from the private `typing._allowed_types`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[function,\n",
       " builtin_function_or_method,\n",
       " method,\n",
       " module,\n",
       " wrapper_descriptor,\n",
       " method-wrapper,\n",
       " method_descriptor]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: `_allowed_types` is a private `typing` attribute, so its contents may vary between Python versions\n",
    "list(typing._allowed_types)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, a function is an allowed type, so `type_hints` returns the same value as `typing.get_type_hints`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain functions are allowed, so the result matches `typing.get_type_hints`\n",
    "def f(a:int)->bool: ... # a function with type hints (allowed)\n",
    "exp = {'a':int,'return':bool}\n",
    "test_eq(type_hints(f), typing.get_type_hints(f))\n",
    "test_eq(type_hints(f), exp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, `class` is not an allowed type, so `type_hints` returns `{}`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classes aren't in `typing._allowed_types`, so no hints are returned even though `__init__` has some\n",
    "class _T:\n",
    "    def __init__(self, a:int=0)->bool: ...\n",
    "assert not type_hints(_T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def annotations(o):\n",
    "    \"Annotations for `o`, its `__init__`, or `type(o)`\"\n",
    "    # Check for `None` explicitly: a truthiness test wrongly skips falsy objects (e.g. empty\n",
    "    # containers), and raises for objects whose `__bool__`/`__len__` fails (e.g. numpy arrays,\n",
    "    # or an instance whose `__len__` needs attributes not yet set when `store_attr` runs)\n",
    "    if o is None: return {}\n",
    "    res = type_hints(o)\n",
    "    # Fall back to the constructor's hints, then to those of the object's type\n",
    "    if not res: res = type_hints(getattr(o,'__init__',None))\n",
    "    if not res: res = type_hints(type(o))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This supports a wider range of situations than `type_hints`, by checking `type()` and `__init__` for annotations too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Works for the class, an instance, the `__init__` function itself, and a plain function\n",
    "for o in _T,_T(),_T.__init__,f: test_eq(annotations(o), exp)\n",
    "assert not annotations(int)\n",
    "assert not annotations(print)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def anno_ret(func):\n",
    "    \"Get the return annotation of `func`\"\n",
    "    # A falsy `func` (e.g. `None`) short-circuits to `None`\n",
    "    if not func: return None\n",
    "    return annotations(func).get('return')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain and generic (`typing.Tuple`) return annotations are both returned\n",
    "def f(x) -> float: return x\n",
    "test_eq(anno_ret(f), float)\n",
    "\n",
    "def f(x) -> typing.Tuple[float,float]: return x\n",
    "assert anno_ret(f)==typing.Tuple[float,float]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your return annotation is `None`, `anno_ret` will return `NoneType` (and not `None`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `-> None` is reported as `NoneType`\n",
    "def f(x) -> None: return x\n",
    "\n",
    "test_eq(anno_ret(f), NoneType)\n",
    "assert anno_ret(f) is not None # returns NoneType instead of None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your function does not have a return type, or if you pass in `None` instead of a function, then `anno_ret` returns `None`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No return annotation (or no function at all) gives `None`\n",
    "def f(x): return x\n",
    "\n",
    "test_eq(anno_ret(f), None)\n",
    "test_eq(anno_ret(None), None) # instead of passing in a func, pass in None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Compare the whole version tuple: checking `minor >= 10` alone would be wrong for e.g. 4.0\n",
    "def _ispy3_10(): return sys.version_info >= (3,10)\n",
    "\n",
    "def signature_ex(obj, eval_str:bool=False):\n",
    "    \"Backport of `inspect.signature(..., eval_str=True)` to <py310\"\n",
    "    from inspect import Signature, Parameter, signature\n",
    "\n",
    "    def _eval_param(ann, k, v):\n",
    "        # Rebuild the parameter with its evaluated annotation, keeping name, kind and default\n",
    "        if k not in ann: return v\n",
    "        return Parameter(v.name, v.kind, annotation=ann[k], default=v.default)\n",
    "\n",
    "    if not eval_str: return signature(obj)\n",
    "    # py3.10+ can evaluate string annotations natively\n",
    "    if _ispy3_10(): return signature(obj, eval_str=eval_str)\n",
    "    sig = signature(obj)\n",
    "    # Note: `type_hints` reports `None` annotations as `NoneType`\n",
    "    ann = type_hints(obj)\n",
    "    params = [_eval_param(ann,k,v) for k,v in sig.parameters.items()]\n",
    "    # Evaluate the return annotation too, as `eval_str=True` does on py3.10+\n",
    "    return Signature(params, return_annotation=ann.get('return', sig.return_annotation))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def union2tuple(t):\n",
    "    \"Return the member types of `t` if it is a `Union` (incl. py3.10+ `X|Y`), otherwise `t` unchanged\"\n",
    "    # `typing.Union[...]` has `Union` as its `__origin__`; `X|Y` unions are `types.UnionType`\n",
    "    # instances (`UnionType` is `None` on older Pythons, which short-circuits that check)\n",
    "    if (getattr(t, '__origin__', None) is Union\n",
    "        or (UnionType and isinstance(t, UnionType))): return t.__args__\n",
    "    return t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only unions (including py3.10+ `X|Y`) are expanded; other types pass through unchanged\n",
    "test_eq(union2tuple(Union[int,str]), (int,str))\n",
    "test_eq(union2tuple(int), int)\n",
    "assert union2tuple(Tuple[int,str])==Tuple[int,str]\n",
    "test_eq(union2tuple((int,str)), (int,str))\n",
    "if UnionType: test_eq(union2tuple(int|str), (int,str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def argnames(f, frame=False):\n",
    "    \"Names of arguments to function or frame `f`\"\n",
    "    # Frames expose their code object as `f_code`, functions as `__code__`\n",
    "    code = f.f_code if frame else f.__code__\n",
    "    # `co_varnames` starts with the positional args, then the keyword-only ones\n",
    "    n_args = code.co_argcount + code.co_kwonlyargcount\n",
    "    return code.co_varnames[:n_args]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `f` here is the last function defined above, `def f(x)`\n",
    "test_eq(argnames(f), ['x'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def with_cast(f):\n",
    "    \"Decorator which uses any parameter annotations as preprocessing functions\"\n",
    "    anno, out_anno, params = annotations(f), anno_ret(f), argnames(f)\n",
    "    # `-> None` is reported as `NoneType`, which can't be called with a value, so don't cast it\n",
    "    c_out = noop if out_anno is None or out_anno is NoneType else out_anno\n",
    "    # `params` lists the positional args first, then the keyword-only ones\n",
    "    n_pos = f.__code__.co_argcount\n",
    "    # `__defaults__` only covers trailing positional params; keyword-only defaults are in `__kwdefaults__`\n",
    "    defaults = dict(zip(reversed(params[:n_pos]), reversed(f.__defaults__ or ())))\n",
    "    defaults.update(f.__kwdefaults__ or {})\n",
    "    @functools.wraps(f)\n",
    "    def _inner(*args, **kwargs):\n",
    "        args = list(args)\n",
    "        for i,v in enumerate(params):\n",
    "            if v not in anno: continue\n",
    "            c = anno[v]\n",
    "            if v in kwargs: kwargs[v] = c(kwargs[v])\n",
    "            # Only positional params are filled from `args` (extra `*args` items belong to no param)\n",
    "            elif i<n_pos and i<len(args): args[i] = c(args[i])\n",
    "            elif v in defaults: kwargs[v] = c(defaults[v])\n",
    "        return c_out(f(*args, **kwargs))\n",
    "    return _inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotated arguments (including defaults) are converted, and so is the return value\n",
    "@with_cast\n",
    "def _f(a, b:Path, c:str='', d=0): return (a,b,c,d)\n",
    "\n",
    "test_eq(_f(1, '.', 3), (1,Path('.'),'3',0))\n",
    "test_eq(_f(1, '.'), (1,Path('.'),'',0))\n",
    "\n",
    "@with_cast\n",
    "def _g(a:int=0)->str: return a\n",
    "\n",
    "test_eq(_g(4.0), '4')\n",
    "test_eq(_g(4.4), '4')\n",
    "test_eq(_g(2), '2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _store_attr(self, anno, **attrs):\n",
    "    # Each value is also recorded in `__stored_args__`, but only if that dict already exists\n",
    "    stored = getattr(self, '__stored_args__', None)\n",
    "    for name,val in attrs.items():\n",
    "        # An annotation, when present, is used as a preprocessing callable\n",
    "        if name in anno: val = anno[name](val)\n",
    "        setattr(self, name, val)\n",
    "        if stored is None: continue\n",
    "        stored[name] = val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def store_attr(names=None, self=None, but='', cast=False, store_args=None, **attrs):\n",
    "    \"Store params named in comma-separated `names` from calling context into attrs in `self`\"\n",
    "    fr = sys._getframe(1)\n",
    "    args = argnames(fr, True)\n",
    "    # Test against `None`, not truthiness: an explicitly passed but falsy `self` (e.g. an empty\n",
    "    # dict subclass) must still be the target, instead of the caller's first argument\n",
    "    if self is not None: args = ('self', *args)\n",
    "    else: self = fr.f_locals[args[0]]\n",
    "    if store_args is None: store_args = not hasattr(self,'__slots__')\n",
    "    if store_args and not hasattr(self, '__stored_args__'): self.__stored_args__ = {}\n",
    "    anno = annotations(self) if cast else {}\n",
    "    if names and isinstance(names,str): names = re.split(', *', names)\n",
    "    # Default to `__slots__` if defined, else every caller argument after the target itself\n",
    "    ns = names if names is not None else getattr(self, '__slots__', args[1:])\n",
    "    added = {n:fr.f_locals[n] for n in ns}\n",
    "    attrs = {**attrs, **added}\n",
    "    if isinstance(but,str): but = re.split(', *', but)\n",
    "    attrs = {k:v for k,v in attrs.items() if k not in but}\n",
    "    return _store_attr(self, anno, **attrs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In its most basic form, you can use `store_attr` to shorten code like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The boilerplate `store_attr` replaces: assigning each argument to `self` by hand\n",
    "class T:\n",
    "    def __init__(self, a,b,c): self.a,self.b,self.c = a,b,c"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...to this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Equivalent, naming the attrs to store in a comma-separated string\n",
    "class T:\n",
    "    def __init__(self, a,b,c): store_attr('a,b,c', self)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class behaves as if we'd used the first form:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional and keyword arguments are both picked up from the caller's frame\n",
    "t = T(1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition, it stores the attrs as a `dict` in `__stored_args__`, which you can use for display, logging, and so forth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `__stored_args__` records every value `store_attr` set\n",
    "test_eq(t.__stored_args__, {'a':1, 'b':3, 'c':2})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since you normally want to use the first argument (often called `self`) for storing attributes, it's optional:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Without `self`, the caller's first argument is used as the target\n",
    "class T:\n",
    "    def __init__(self, a,b,c:str): store_attr('a,b,c')\n",
    "\n",
    "t = T(1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "# Local variables (here `c`) can be stored too, not just arguments\n",
    "class _T:\n",
    "    def __init__(self, a,b):\n",
    "        c = 2\n",
    "        store_attr('a,b,c')\n",
    "\n",
    "t = _T(1,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `cast=True` any parameter annotations will be used as preprocessing functions for the corresponding arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With `cast=True`, the `listify` and `str` annotations are applied to `a` and `c`\n",
    "class T:\n",
    "    def __init__(self, a:listify, b, c:str): store_attr('a,b,c', cast=True)\n",
    "\n",
    "t = T(1,c=2,b=3)\n",
    "assert t.a==[1] and t.b==3 and t.c=='2'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can inherit from a class using `store_attr`, and just call it again to add in any new attributes added in the derived class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each `store_attr` call only adds the attrs named in its own `__init__`\n",
    "class T2(T):\n",
    "    def __init__(self, d, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        store_attr('d')\n",
    "\n",
    "t = T2(d=1,a=2,b=3,c=4)\n",
    "assert t.a==2 and t.b==3 and t.c==4 and t.d==1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can skip passing a list of attrs to store. In this case, all arguments passed to the method are stored:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With no `names`, every argument except `self` is stored\n",
    "class T:\n",
    "    def __init__(self, a,b,c): store_attr()\n",
    "\n",
    "t = T(1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined with `super().__init__`, the subclass stores its own new argument `d`\n",
    "class T4(T):\n",
    "    def __init__(self, d, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        store_attr()\n",
    "\n",
    "t = T4(4, a=1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2 and t.d==4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword-only arguments (and their defaults) are stored too\n",
    "class T4:\n",
    "    def __init__(self, *, a: int, b: float = 1):\n",
    "        store_attr()\n",
    "        \n",
    "t = T4(a=3)\n",
    "assert t.a==3 and t.b==1\n",
    "t = T4(a=3, b=2)\n",
    "assert t.a==3 and t.b==2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "# ensure that subclasses work with or without `store_attr`\n",
    "# (a subclass taking only `**kwargs` stores nothing itself; `T.__init__` stores `a,b,c`)\n",
    "class T4(T):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        store_attr()\n",
    "\n",
    "t = T4(a=1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2\n",
    "\n",
    "class T4(T): pass\n",
    "\n",
    "t = T4(a=1,c=2,b=3)\n",
    "assert t.a==1 and t.b==3 and t.c==2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#ensure that kwargs work with names==None\n",
    "# (kwargs passed through to `store_attr` are stored alongside the named args)\n",
    "class T:\n",
    "    def __init__(self, a,b,c,**kwargs): store_attr(**kwargs)\n",
    "\n",
    "t = T(1,c=2,b=3,d=4,e=-1)\n",
    "assert t.a==1 and t.b==3 and t.c==2 and t.d==4 and t.e==-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#ensure that kwargs work with names==''\n",
    "# (an empty `names` stores no args, so the manually set `self.a` is kept)\n",
    "class T:\n",
    "    def __init__(self, a, **kwargs):\n",
    "        self.a = a+1\n",
    "        store_attr('', **kwargs)\n",
    "\n",
    "t = T(a=1, d=4)\n",
    "test_eq(t.a, 2)\n",
    "test_eq(t.d, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can skip some attrs by passing `but`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `but` excludes the named attrs from being stored\n",
    "class T:\n",
    "    def __init__(self, a,b,c): store_attr(but='a')\n",
    "\n",
    "t = T(1,c=2,b=3)\n",
    "assert t.b==3 and t.c==2\n",
    "assert not hasattr(t,'a')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also pass keywords to `store_attr`, which is identical to setting the attrs directly, but also stores them in `__stored_args__`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword args given to `store_attr` are set directly as attrs\n",
    "class T:\n",
    "    def __init__(self): store_attr(a=1)\n",
    "\n",
    "t = T()\n",
    "assert t.a==1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use `store_attr` inside functions, by passing the object to store the attrs on as `self`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing `self=t` makes `t` the target, so all of the function's args are stored on it\n",
    "def create_T(a, b):\n",
    "    t = SimpleNamespace()\n",
    "    store_attr(self=t)\n",
    "    return t\n",
    "\n",
    "t = create_T(a=1, b=2)\n",
    "assert t.a==1 and t.b==2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def attrdict(o, *ks, default=None):\n",
    "    \"Dict from each `k` in `ks` to `getattr(o,k)`\"\n",
    "    # Attrs missing from `o` map to `default` rather than raising\n",
    "    res = {}\n",
    "    for k in ks: res[k] = getattr(o, k, default)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the requested attrs are returned\n",
    "class T:\n",
    "    def __init__(self, a,b,c): store_attr()\n",
    "\n",
    "t = T(1,c=2,b=3)\n",
    "test_eq(attrdict(t,'b','c'), {'b':3, 'c':2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def properties(cls, *ps):\n",
    "    \"Change attrs in `cls` with names in `ps` to properties\"\n",
    "    # Each named attr becomes a getter-only property wrapping the original function\n",
    "    for name in ps:\n",
    "        getter = getattr(cls, name)\n",
    "        setattr(cls, name, property(getter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only `a` is converted; `b` is still a regular method\n",
    "class T:\n",
    "    def a(self): return 1\n",
    "    def b(self): return 2\n",
    "properties(T,'a')\n",
    "\n",
    "test_eq(T().a,1)\n",
    "test_eq(T().b(),2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# `camel2words`: an uppercase letter right after a lowercase one, or an uppercase letter\n",
    "# (not at the start of the string) that is followed by a lowercase one\n",
    "_c2w_re = re.compile(r'((?<=[a-z])[A-Z]|(?<!\\A)[A-Z](?=[a-z]))')\n",
    "# `camel2snake`: any character followed by a capitalised word, and a lowercase letter or\n",
    "# digit followed by an uppercase letter\n",
    "_camel_re1 = re.compile('(.)([A-Z][a-z]+)')\n",
    "_camel_re2 = re.compile('([a-z0-9])([A-Z])')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def camel2words(s, space=' '):\n",
    "    \"Convert CamelCase to 'spaced words'\"\n",
    "    # Each boundary letter found by `_c2w_re` is re-emitted with `space` in front of it\n",
    "    return _c2w_re.sub(rf'{space}\\1', s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A separator is inserted before each inner capitalised word\n",
    "test_eq(camel2words('ClassAreCamel'), 'Class Are Camel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def camel2snake(name):\n",
    "    \"Convert CamelCase to snake_case\"\n",
    "    # Split before capitalised words first, then between a lowercase letter/digit and an uppercase one\n",
    "    words = _camel_re1.sub(r'\\1_\\2', name)\n",
    "    words = _camel_re2.sub(r'\\1_\\2', words)\n",
    "    return words.lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Existing underscores are kept, so already-snake input gets a doubled `__`\n",
    "test_eq(camel2snake('ClassAreCamel'), 'class_are_camel')\n",
    "test_eq(camel2snake('Already_Snake'), 'already__snake')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def snake2camel(s):\n",
    "    \"Convert snake_case to CamelCase\"\n",
    "    # `title()` capitalises each word (lowercasing the rest of it); dropping the `_`s joins them\n",
    "    return s.title().replace('_', '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each `_`-separated word is capitalised and the underscores are removed\n",
    "test_eq(snake2camel('a_b_cc'), 'ABCc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def class2attr(self, cls_name):\n",
    "    \"Return the snake-cased name of the class; strip ending `cls_name` if it exists.\"\n",
    "    return camel2snake(re.sub(rf'{cls_name}$', '', self.__class__.__name__) or cls_name.lower())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parent:\n",
    "    @property\n",
    "    def name(self): return class2attr(self, 'Parent')\n",
    "\n",
    "class ChildOfParent(Parent): pass\n",
    "class ParentChildOf(Parent): pass\n",
    "\n",
    "p = Parent()\n",
    "cp = ChildOfParent()\n",
    "cp2 = ParentChildOf()\n",
    "\n",
    "test_eq(p.name, 'parent')\n",
    "test_eq(cp.name, 'child_of')\n",
    "test_eq(cp2.name, 'parent_child_of')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def getcallable(o, attr):\n",
    "    \"Calls `getattr` with a default of `noop`\"\n",
    "    return getattr(o, attr, noop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Math:\n",
    "    def addition(self,a,b): return a+b\n",
    "\n",
    "m = Math()\n",
    "\n",
    "test_eq(getcallable(m, \"addition\")(a=1,b=2), 3)\n",
    "test_eq(getcallable(m, \"subtraction\")(a=1,b=2), None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def getattrs(o, *attrs, default=None):\n",
    "    \"List of all `attrs` in `o`\"\n",
    "    return [getattr(o,attr,default) for attr in attrs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fractions import Fraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "getattrs(Fraction(1,2), 'numerator', 'denominator')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def hasattrs(o,attrs):\n",
    "    \"Test whether `o` contains all `attrs`\"\n",
    "    return all(hasattr(o,attr) for attr in attrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert hasattrs(1,('imag','real'))\n",
    "assert not hasattrs(1,('imag','foo'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def setattrs(dest, flds, src):\n",
    "    f = dict.get if isinstance(src, dict) else getattr\n",
    "    flds = re.split(r\",\\s*\", flds)\n",
    "    for fld in flds: setattr(dest, fld, f(src, fld))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = dict(a=1,bb=\"2\",ignore=3)\n",
    "o = SimpleNamespace()\n",
    "setattrs(o, \"a,bb\", d)\n",
    "test_eq(o.a, 1)\n",
    "test_eq(o.bb, \"2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = SimpleNamespace(a=1,bb=\"2\",ignore=3)\n",
    "o = SimpleNamespace()\n",
    "setattrs(o, \"a,bb\", d)\n",
    "test_eq(o.a, 1)\n",
    "test_eq(o.bb, \"2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def try_attrs(obj, *attrs):\n",
    "    \"Return first attr that exists in `obj`\"\n",
    "    for att in attrs:\n",
    "        try: return getattr(obj, att)\n",
    "        except: pass\n",
    "    raise AttributeError(attrs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(try_attrs(1, 'real'), 1)\n",
    "test_eq(try_attrs(1, 'foobar', 'real'), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute Delegation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GetAttrBase:\n",
    "    \"Basic delegation of `__getattr__` and `__dir__`\"\n",
    "    _attr=noop\n",
    "    def __getattr__(self,k):\n",
    "        if k[0]=='_' or k==self._attr: return super().__getattr__(k)\n",
    "        return self._getattr(getattr(self, self._attr)[k])\n",
    "    def __dir__(self): return custom_dir(self, getattr(self, self._attr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GetAttr:\n",
    "    \"Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`\"\n",
    "    _default='default'\n",
    "    def _component_attr_filter(self,k):\n",
    "        if k.startswith('__') or k in ('_xtra',self._default): return False\n",
    "        xtra = getattr(self,'_xtra',None)\n",
    "        return xtra is None or k in xtra\n",
    "    def _dir(self): return [k for k in dir(getattr(self,self._default)) if self._component_attr_filter(k)]\n",
    "    def __getattr__(self,k):\n",
    "        if self._component_attr_filter(k):\n",
    "            attr = getattr(self,self._default,None)\n",
    "            if attr is not None: return getattr(attr,k)\n",
    "        raise AttributeError(k)\n",
    "    def __dir__(self): return custom_dir(self,self._dir())\n",
    "#     def __getstate__(self): return self.__dict__\n",
    "    def __setstate__(self,data): self.__dict__.update(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L501){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### GetAttr\n",
       "\n",
       ">      GetAttr ()\n",
       "\n",
       "*Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L501){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### GetAttr\n",
       "\n",
       ">      GetAttr ()\n",
       "\n",
       "*Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(GetAttr, title_level=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inherit from `GetAttr` to have attr access passed down to an instance attribute. \n",
    "This makes it easy to create composites that don't require callers to know about their components.  For a more detailed discussion of how this works as well as relevant context, we suggest reading the [delegated composition section of this blog article](https://www.fast.ai/2019/08/06/delegation/).\n",
    "\n",
    "You can customise the behaviour of `GetAttr` in subclasses via;\n",
    "- `_default`\n",
    "    - By default, this is set to `'default'`, so attr access is passed down to `self.default`\n",
    "    - `_default` can be set to the name of any instance attribute that does not start with dunder `__`\n",
    "- `_xtra`\n",
    "    - By default, this is `None`, so all attr access is passed down\n",
    "    - You can limit which attrs get passed down by setting `_xtra` to a list of attribute names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illuminate the utility of `GetAttr`, suppose we have the following two classes, `_WebPage` which is a superclass of `_ProductPage`, which we wish to compose like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _WebPage:\n",
    "    def __init__(self, title, author=\"Jeremy\"):\n",
    "        self.title,self.author = title,author\n",
    "\n",
    "class _ProductPage:\n",
    "    def __init__(self, page, price): self.page,self.price = page,price\n",
    "        \n",
    "page = _WebPage('Soap', author=\"Sylvain\")\n",
    "p = _ProductPage(page, 15.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How do we make it so we can just write `p.author`, instead of `p.page.author` to access the `author` attribute?  We can use `GetAttr`, of course!  First, we subclass `GetAttr` when defining `_ProductPage`.  Next, we set `self.default` to the object whose attributes we want to be able to access directly, which in this case is the `page` argument passed on initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _ProductPage(GetAttr):\n",
    "    def __init__(self, page, price): self.default,self.price = page,price #self.default allows you to access page directly.\n",
    "\n",
    "p = _ProductPage(page, 15.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can access the `author` attribute directly from the instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(p.author, 'Sylvain')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you wish to store the object you are composing in an attribute other than `self.default`, you can set the class attribute `_data` as shown below.  This is useful in the case where you might have a name collision with `self.default`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _C(GetAttr):\n",
    "    _default = '_data' # use different component name; `self._data` rather than `self.default`\n",
    "    def __init__(self,a): self._data = a\n",
    "    def foo(self): noop\n",
    "\n",
    "t = _C('Hi')\n",
    "test_eq(t._data, 'Hi') \n",
    "test_fail(lambda: t.default) # we no longer have self.default\n",
    "test_eq(t.lower(), 'hi')\n",
    "test_eq(t.upper(), 'HI')\n",
    "assert 'lower' in dir(t)\n",
    "assert 'upper' in dir(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, all attributes and methods of the object you are composing are retained. In the below example, we compose a `str` object with the class `_C`. This allows us to directly call string methods on instances of class `_C`, such as `str.lower()` or `str.upper()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _C(GetAttr):\n",
    "    # allow all attributes and methods to get passed to `self.default` (by leaving _xtra=None)\n",
    "    def __init__(self,a): self.default = a\n",
    "    def foo(self): noop\n",
    "\n",
    "t = _C('Hi')\n",
    "test_eq(t.lower(), 'hi')\n",
    "test_eq(t.upper(), 'HI')\n",
    "assert 'lower' in dir(t)\n",
    "assert 'upper' in dir(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, you can choose which attributes or methods to retain by defining a class attribute `_xtra`, which is a list of allowed attribute and method names to delegate.  In the below example, we only delegate the `lower` method from the composed `str` object when defining class `_C`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _C(GetAttr):\n",
    "    _xtra = ['lower'] # specify which attributes get passed to `self.default`\n",
    "    def __init__(self,a): self.default = a\n",
    "    def foo(self): noop\n",
    "\n",
    "t = _C('Hi')\n",
    "test_eq(t.default, 'Hi')\n",
    "test_eq(t.lower(), 'hi')\n",
    "test_fail(lambda: t.upper()) # upper wasn't in _xtra, so it isn't available to be called\n",
    "assert 'lower' in dir(t)\n",
    "assert 'upper' not in dir(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You must be careful to properly set an instance attribute in `__init__` that corresponds to the class attribute `_default`.  The below example sets the class attribute `_default` to `data`, but erroneously fails to define `self.data` (and instead defines `self.default`).\n",
    "\n",
    "Failing to properly set instance attributes leads to errors when you try to access methods directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _C(GetAttr):\n",
    "    _default = 'data' # use a bad component name; i.e. self.data does not exist\n",
    "    def __init__(self,a): self.default = a\n",
    "    def foo(self): noop\n",
    "        \n",
    "# TODO: should we raise an error when we create a new instance ...\n",
    "t = _C('Hi')\n",
    "test_eq(t.default, 'Hi')\n",
    "# ... or is it enough for all GetAttr features to raise errors\n",
    "test_fail(lambda: t.data)\n",
    "test_fail(lambda: t.lower())\n",
    "test_fail(lambda: t.upper())\n",
    "test_fail(lambda: dir(t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "# I don't think this test is essential to the docs but it probably makes sense to\n",
    "# check that everything works when we set both _xtra and _default to non-default values\n",
    "class _C(GetAttr):\n",
    "    _xtra = ['lower', 'upper']\n",
    "    _default = 'data'\n",
    "    def __init__(self,a): self.data = a\n",
    "    def foo(self): noop\n",
    "\n",
    "t = _C('Hi')\n",
    "test_eq(t.data, 'Hi')\n",
    "test_eq(t.lower(), 'hi')\n",
    "test_eq(t.upper(), 'HI')\n",
    "assert 'lower' in dir(t)\n",
    "assert 'upper' in dir(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#  when consolidating the filter logic, I choose the previous logic from \n",
    "# __getattr__  k.startswith('__') rather than\n",
    "# _dir         k.startswith('_'). \n",
    "class _C(GetAttr):\n",
    "    def __init__(self): self.default = type('_D', (), {'_under': 1, '__dunder': 2})() \n",
    "    \n",
    "t = _C()\n",
    "test_eq(t.default._under, 1)\n",
    "test_eq(t._under, 1)           # _ prefix attr access is allowed on component\n",
    "assert '_under' in dir(t)\n",
    "\n",
    "test_eq(t.default.__dunder, 2)\n",
    "test_fail(lambda: t.__dunder)  # __ prefix attr access is not allowed on component\n",
    "assert '__dunder' not in dir(t)\n",
    "\n",
    "assert t.__dir__ is not None   # __ prefix attr access is allowed on composite\n",
    "assert '__dir__' in dir(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#Failing test. TODO: make GetAttr pickle-safe\n",
    "\n",
    "# class B:\n",
    "#     def __init__(self): self.a = A()\n",
    "\n",
    "# @funcs_kwargs\n",
    "# class A(GetAttr):\n",
    "#     wif=after_iter= noops\n",
    "#     _methods = 'wif after_iter'.split()\n",
    "#     _default = 'dataset'\n",
    "#     def __init__(self, **kwargs): pass\n",
    "    \n",
    "# a = A()\n",
    "# b = A(wif=a.wif)\n",
    "        \n",
    "# a = A()\n",
    "# b = A(wif=a.wif)\n",
    "# tst = pickle.dumps(b)\n",
    "# c = pickle.loads(tst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def delegate_attr(self, k, to):\n",
    "    \"Use in `__getattr__` to delegate to attr `to` without inheriting from `GetAttr`\"\n",
    "    if k.startswith('_') or k==to: raise AttributeError(k)\n",
    "    try: return getattr(getattr(self,to), k)\n",
    "    except AttributeError: raise AttributeError(k) from None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`delegate_attr` is a functional way to delegate attributes, and is an alternative to `GetAttr`.  We recommend reading the documentation of `GetAttr` for more details around delegation.\n",
    "\n",
    "You can use achieve delegation when you define `__getattr__` by using `delegate_attr`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _C:\n",
    "    def __init__(self, o): self.o = o # self.o corresponds to the `to` argument in delegate_attr.\n",
    "    def __getattr__(self, k): return delegate_attr(self, k, to='o')\n",
    "    \n",
    "\n",
    "t = _C('HELLO') # delegates to a string\n",
    "test_eq(t.lower(), 'hello')\n",
    "\n",
    "t = _C(np.array([5,4,3])) # delegates to a numpy array\n",
    "test_eq(t.sum(), 12)\n",
    "\n",
    "t = _C(pd.DataFrame({'a': [1,2], 'b': [3,4]})) # delegates to a pandas.DataFrame\n",
    "test_eq(t.b.max(), 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extensible Types"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ShowPrint` is a base class that defines a `show` method, which is used primarily for callbacks in fastai that expect this method to be defined."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "#|hide\n",
    "class ShowPrint:\n",
    "    \"Base class that prints for `show`\"\n",
    "    def show(self, *args, **kwargs): print(str(self))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Int`, `Float`, and `Str` extend `int`, `float` and `str` respectively by adding an additional `show` method by inheriting from `ShowPrint`.\n",
    "\n",
    "The code for `Int` is shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "#|hide\n",
    "class Int(int,ShowPrint):\n",
    "    \"An extensible `int`\"\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export \n",
    "#|hide\n",
    "class Str(str,ShowPrint):\n",
    "    \"An extensible `str`\"\n",
    "    pass\n",
    "class Float(float,ShowPrint):\n",
    "    \"An extensible `float`\"\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "2.0\n",
      "Hello\n"
     ]
    }
   ],
   "source": [
    "Int(0).show()\n",
    "Float(2.0).show()\n",
    "Str('Hello').show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collection functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions that manipulate popular python collections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def partition(coll, f):\n",
    "    \"Partition a collection by a predicate\"\n",
    "    ts,fs = [],[]\n",
    "    for o in coll: (fs,ts)[f(o)].append(o)\n",
    "    if isinstance(coll,tuple):\n",
    "        typ = type(coll)\n",
    "        ts,fs = typ(ts),typ(fs)\n",
    "    return ts,fs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts,fs = partition(range(10), mod(2))\n",
    "test_eq(fs, [0,2,4,6,8])\n",
    "test_eq(ts, [1,3,5,7,9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def flatten(o):\n",
    "    \"Concatenate all collections and items as a generator\"\n",
    "    for item in o:\n",
    "        if isinstance(item, str): yield item; continue\n",
    "        try: yield from flatten(item)\n",
    "        except TypeError: yield item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def concat(colls)->list:\n",
    "    \"Concatenate all collections and items as a list\"\n",
    "    return list(flatten(colls))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concat([(o for o in range(2)),[2,3,4], 5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['abc', 'xyz', 'foo', 'bar']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "concat([[\"abc\", \"xyz\"], [\"foo\", \"bar\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def strcat(its, sep:str='')->str:\n",
    "    \"Concatenate stringified items `its`\"\n",
    "    return sep.join(map(str,its))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(strcat(['a',2]), 'a2')\n",
    "test_eq(strcat(['a',2], ';'), 'a;2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def detuplify(x):\n",
    "    \"If `x` is a tuple with one thing, extract it\"\n",
    "    return None if len(x)==0 else x[0] if len(x)==1 and getattr(x, 'ndim', 1)==1 else x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(detuplify(()),None)\n",
    "test_eq(detuplify([1]),1)\n",
    "test_eq(detuplify([1,2]), [1,2])\n",
    "test_eq(detuplify(np.array([[1,2]])), np.array([[1,2]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def replicate(item,match):\n",
    "    \"Create tuple of `item` copied `len(match)` times\"\n",
    "    return (item,)*len(match)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = [1,1]\n",
    "test_eq(replicate([1,2], t),([1,2],[1,2]))\n",
    "test_eq(replicate(1, t),(1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def setify(o):\n",
    "    \"Turn any list like-object into a set.\"\n",
    "    return o if isinstance(o,set) else set(listify(o))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "test_eq(setify(None),set())\n",
    "test_eq(setify('abc'),{'abc'})\n",
    "test_eq(setify([1,2,2]),{1,2})\n",
    "test_eq(setify(range(0,3)),{0,1,2})\n",
    "test_eq(setify({1,2}),{1,2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def merge(*ds):\n",
    "    \"Merge all dictionaries in `ds`\"\n",
    "    return {k:v for d in ds if d is not None for k,v in d.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(merge(), {})\n",
    "test_eq(merge(dict(a=1,b=2)), dict(a=1,b=2))\n",
    "test_eq(merge(dict(a=1,b=2), dict(b=3,c=4), None), dict(a=1, b=3, c=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def range_of(x):\n",
    "    \"All indices of collection `x` (i.e. `list(range(len(x)))`)\"\n",
    "    return list(range(len(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(range_of([1,1,1,1]), [0,1,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def groupby(x, key, val=noop):\n",
    "    \"Like `itertools.groupby` but doesn't need to be sorted, and isn't lazy, plus some extensions\"\n",
    "    if   isinstance(key,int): key = itemgetter(key)\n",
    "    elif isinstance(key,str): key = attrgetter(key)\n",
    "    if   isinstance(val,int): val = itemgetter(val)\n",
    "    elif isinstance(val,str): val = attrgetter(val)\n",
    "    res = {}\n",
    "    for o in x: res.setdefault(key(o), []).append(val(o))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(groupby('aa ab bb'.split(), itemgetter(0)), {'a':['aa','ab'], 'b':['bb']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's an example of how to *invert* a grouping, using an `int` as `key` (which uses `itemgetter`; passing a `str` will use `attrgetter`), and using a `val` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: [0], 3: [0, 2], 7: [0], 5: [3, 7], 8: [4], 4: [5]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = {0: [1, 3, 7], 2: [3], 3: [5], 4: [8], 5: [4], 7: [5]}\n",
    "groupby(((o,k) for k,v in d.items() for o in v), 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def last_index(x, o):\n",
    "    \"Finds the last index of occurence of `x` in `o` (returns -1 if no occurence)\"\n",
    "    try: return next(i for i in reversed(range(len(o))) if o[i] == x)\n",
    "    except StopIteration: return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(last_index(9, [1, 2, 9, 3, 4, 9, 10]), 5)\n",
    "test_eq(last_index(6, [1, 2, 9, 3, 4, 9, 10]), -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def filter_dict(d, func):\n",
    "    \"Filter a `dict` using `func`, applied to keys and values\"\n",
    "    return {k:v for k,v in d.items() if func(k,v)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{65: 'A', 66: 'B', 67: 'C', 68: 'D', 69: 'E', 70: 'F', 71: 'G', 72: 'H'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "letters = {o:chr(o) for o in range(65,73)}\n",
    "letters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{65: 'A', 66: 'B', 70: 'F', 71: 'G'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filter_dict(letters, lambda k,v: k<67 or v in 'FG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def filter_keys(d, func):\n",
    "    \"Filter a `dict` using `func`, applied to keys\"\n",
    "    return {k:v for k,v in d.items() if func(k)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{65: 'A', 66: 'B'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filter_keys(letters, lt(67))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def filter_values(d, func):\n",
    "    \"Filter a `dict` using `func`, applied to values\"\n",
    "    return {k:v for k,v in d.items() if func(v)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{70: 'F', 71: 'G'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filter_values(letters, in_('FG'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def cycle(o):\n",
    "    \"Like `itertools.cycle` except creates list of `None`s if `o` is empty\"\n",
    "    o = listify(o)\n",
    "    return itertools.cycle(o) if o is not None and len(o) > 0 else itertools.cycle([None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(itertools.islice(cycle([1,2,3]),5), [1,2,3,1,2])\n",
    "test_eq(itertools.islice(cycle([]),3), [None]*3)\n",
    "test_eq(itertools.islice(cycle(None),3), [None]*3)\n",
    "test_eq(itertools.islice(cycle(1),3), [1,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def zip_cycle(x, *args):\n",
    "    \"Like `itertools.zip_longest` but `cycle`s through elements of all but first argument\"\n",
    "    return zip(x, *map(cycle,args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(zip_cycle([1,2,3,4],list('abc')), [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'a')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def sorted_ex(iterable, key=None, reverse=False):\n",
    "    \"Like `sorted`, but if key is str use `attrgetter`; if int use `itemgetter`\"\n",
    "    if isinstance(key,str):   k=lambda o:getattr(o,key,0)\n",
    "    elif isinstance(key,int): k=itemgetter(key)\n",
    "    else: k=key\n",
    "    return sorted(iterable, key=k, reverse=reverse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def not_(f):\n",
    "    \"Create new function that negates result of `f`\"\n",
    "    def _f(*args, **kwargs): return not f(*args, **kwargs)\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(a): return a>0\n",
    "test_eq(f(1),True)\n",
    "test_eq(not_(f)(1),False)\n",
    "test_eq(not_(f)(a=-1),True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def argwhere(iterable, f, negate=False, **kwargs):\n",
    "    \"Like `filter_ex`, but return indices for matching items\"\n",
    "    if kwargs: f = partial(f,**kwargs)\n",
    "    if negate: f = not_(f)\n",
    "    return [i for i,o in enumerate(iterable) if f(o)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def filter_ex(iterable, f=noop, negate=False, gen=False, **kwargs):\n",
    "    \"Like `filter`, but `f` (default `noop`; `None` keeps everything) receives `kwargs`, `negate` inverts it, and `gen` returns the lazy iterator instead of a list\"\n",
    "    pred = (lambda _: True) if f is None else f\n",
    "    if kwargs: pred = partial(pred, **kwargs)\n",
    "    if negate: pred = not_(pred)\n",
    "    res = filter(pred, iterable)\n",
    "    return res if gen else list(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def range_of(a, b=None, step=None):\n",
    "    \"All indices of collection `a` if it is one, otherwise `list(range(a, b, step))`; with `b` omitted, `a` is the stop (like `range(a)`)\"\n",
    "    if is_coll(a): a = len(a)\n",
    "    # With no `b`, `a` is the stop value, so `step` alone also works (e.g. `range_of(4, step=2)` -> `[0, 2]`)\n",
    "    if b is None: a,b = 0,a\n",
    "    return list(range(a, b, step if step is not None else 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(range_of([1,1,1,1]), [0,1,2,3])\n",
    "test_eq(range_of(4), [0,1,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def renumerate(iterable, start=0):\n",
    "    \"Same as `enumerate`, but returns index as 2nd element instead of 1st\"\n",
    "    return ((o,i) for i,o in enumerate(iterable, start=start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(renumerate('abc'), (('a',0),('b',1),('c',2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def first(x, f=None, negate=False, **kwargs):\n",
    "    \"First element of `x`, optionally filtered by `f`, or None if missing\"\n",
    "    x = iter(x)\n",
    "    if f: x = filter_ex(x, f=f, negate=negate, gen=True, **kwargs)\n",
    "    return next(x, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(first(['a', 'b', 'c', 'd', 'e']), 'a')\n",
    "test_eq(first([False]), False)\n",
    "test_eq(first([False], noop), None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def only(o):\n",
    "    \"Return the only item of `o`, raise if `o` doesn't have exactly one item\"\n",
    "    it = iter(o)\n",
    "    try: res = next(it)\n",
    "    except StopIteration: raise ValueError('iterable has 0 items') from None\n",
    "    try: next(it)\n",
    "    except StopIteration: return res\n",
    "    raise ValueError(f'iterable has more than 1 item')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "test_fail(lambda: only([]), contains='iterable has 0 items')\n",
    "test_eq(only([0]), 0)\n",
    "test_fail(lambda: only([0,1]), contains='iterable has more than 1 item')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def nested_attr(o, attr, default=None):\n",
    "    \"Same as `getattr`, but if `attr` includes a `.`, then looks inside nested objects\"\n",
    "    try:\n",
    "        for a in attr.split(\".\"): o = getattr(o, a)\n",
    "    except AttributeError: return default\n",
    "    return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = SimpleNamespace(b=(SimpleNamespace(c=1)))\n",
    "test_eq(nested_attr(a, 'b.c'), getattr(getattr(a, 'b'), 'c'))\n",
    "test_eq(nested_attr(a, 'b.d'), None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def nested_setdefault(o, attr, default):\n",
    "    \"Same as `setdefault`, but if `attr` includes a `.`, then looks inside nested objects\"\n",
    "    attrs = attr.split('.')\n",
    "    for a in attrs[:-1]: o = o.setdefault(a, type(o)())\n",
    "    return o.setdefault(attrs[-1], default)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "o = {'e':'f'}\n",
    "test_eq(nested_setdefault(o, 'a.b.c', 'd'), 'd')\n",
    "test_eq(o, {'a':{'b':{'c':'d'}},'e':'f'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "o = {'a':'b'}\n",
    "test_eq(nested_setdefault(o, 'a', 'c'), 'b')\n",
    "test_eq(o, {'a':'b'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "o = {'a':{'b':'c'}}\n",
    "test_eq(nested_setdefault(o, 'a.b', 'd'), 'c')\n",
    "test_eq(o,{'a':{'b':'c'}})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def nested_callable(o, attr):\n",
    "    \"Like `nested_attr`, but falls back to `noop` when the dotted `attr` cannot be resolved\"\n",
    "    return nested_attr(o, attr, default=noop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = SimpleNamespace(b=(SimpleNamespace(c=1)))\n",
    "test_eq(nested_callable(a, 'b.c'), getattr(getattr(a, 'b'), 'c'))\n",
    "test_eq(nested_callable(a, 'b.d'), noop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _access(coll, idx):\n",
    "    if isinstance(idx,str) and hasattr(coll, idx): return getattr(coll, idx)\n",
    "    if hasattr(coll, 'get'): return coll.get(idx, None)\n",
    "    try: length = len(coll)\n",
    "    except TypeError: length = 0\n",
    "    if isinstance(idx,int) and idx<length: return coll[idx]\n",
    "    return None\n",
    "\n",
    "def _nested_idx(coll, *idxs):\n",
    "    *idxs,last_idx = idxs\n",
    "    for idx in idxs:\n",
    "        if isinstance(idx,str) and hasattr(coll, idx): coll = getattr(coll, idx)\n",
    "        else:\n",
    "            if isinstance(coll,str) or not isinstance(coll, typing.Collection): return None,None\n",
    "            coll = coll.get(idx, None) if hasattr(coll, 'get') else coll[idx] if idx<len(coll) else None\n",
    "    return coll,last_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def nested_idx(coll, *idxs):\n",
    "    \"Index into nested collections, dicts, etc, with `idxs`; `None` if the path is missing\"\n",
    "    if not idxs: return coll\n",
    "    coll,idx = _nested_idx(coll, *idxs)\n",
    "    # Only a broken path (`None`) short-circuits: falsy containers such as `{}`, `[]` or `0` are still indexed\n",
    "    # (giving `None`, not themselves), and `is None` avoids truth-testing objects whose `bool` may be ambiguous\n",
    "    if coll is None: return None\n",
    "    return _access(coll, idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = {'b':[1,{'c':2}]}\n",
    "test_eq(nested_idx(a, 'nope'), None)\n",
    "test_eq(nested_idx(a, 'nope', 'nup'), None)\n",
    "test_eq(nested_idx(a, 'b', 3), None)\n",
    "test_eq(nested_idx(a), a)\n",
    "test_eq(nested_idx(a, 'b'), [1,{'c':2}])\n",
    "test_eq(nested_idx(a, 'b', 1), {'c':2})\n",
    "test_eq(nested_idx(a, 'b', 1, 'c'), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = SimpleNamespace(b=[1,{'c':2}])\n",
    "test_eq(nested_idx(a, 'nope'), None)\n",
    "test_eq(nested_idx(a, 'nope', 'nup'), None)\n",
    "test_eq(nested_idx(a, 'b', 3), None)\n",
    "test_eq(nested_idx(a), a)\n",
    "test_eq(nested_idx(a, 'b'), [1,{'c':2}])\n",
    "test_eq(nested_idx(a, 'b', 1), {'c':2})\n",
    "test_eq(nested_idx(a, 'b', 1, 'c'), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def set_nested_idx(coll, value, *idxs):\n",
    "    \"Set `value` at `idxs` in `coll`, walking the path like `nested_idx` and assigning the final index with `[]`\"\n",
    "    # `_nested_idx` returns the container that holds the last index, plus that index\n",
    "    coll,idx = _nested_idx(coll, *idxs)\n",
    "    coll[idx] = value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_nested_idx(a, 3, 'b', 0)\n",
    "test_eq(nested_idx(a, 'b', 0), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def val2idx(x):\n",
    "    \"Dict from value to index\"\n",
    "    return {v:k for k,v in enumerate(x)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(val2idx([1,2,3]), {3:2,1:0,2:1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def uniqueify(x, sort=False, bidir=False, start=None):\n",
    "    \"Unique elements in `x`, optional `sort`, optional return reverse correspondence, optional prepend with elements.\"\n",
    "    res = list(dict.fromkeys(x))\n",
    "    if start is not None: res = listify(start)+res\n",
    "    if sort: res.sort()\n",
    "    return (res,val2idx(res)) if bidir else res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = [1,1,0,5,0,3]\n",
    "test_eq(uniqueify(t),[1,0,5,3])\n",
    "test_eq(uniqueify(t, sort=True),[0,1,3,5])\n",
    "test_eq(uniqueify(t, start=[7,8,6]), [7,8,6,1,0,5,3])\n",
    "v,o = uniqueify(t, bidir=True)\n",
    "test_eq(v,[1,0,5,3])\n",
    "test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})\n",
    "v,o = uniqueify(t, sort=True, bidir=True)\n",
    "test_eq(v,[0,1,3,5])\n",
    "test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "\n",
    "# looping functions from https://github.com/willmcgugan/rich/blob/master/rich/_loop.py\n",
    "def loop_first_last(values):\n",
    "    \"Iterate and generate a tuple with a flag for first and last value.\"\n",
    "    iter_values = iter(values)\n",
    "    try: previous_value = next(iter_values)\n",
    "    except StopIteration: return\n",
    "    first = True\n",
    "    for value in iter_values:\n",
    "        yield first,False,previous_value\n",
    "        first,previous_value = False,value\n",
    "    yield first,True,previous_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(loop_first_last(range(3)), [(True,False,0), (False,False,1), (False,True,2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def loop_first(values):\n",
    "    \"Yield `(is_first, value)` for each item of `values`\"\n",
    "    triples = loop_first_last(values)\n",
    "    return ((is_first, val) for is_first, _last, val in triples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(loop_first(range(3)), [(True,0), (False,1), (False,2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def loop_last(values):\n",
    "    \"Yield `(is_last, value)` for each item of `values`\"\n",
    "    triples = loop_first_last(values)\n",
    "    return ((is_last, val) for _first, is_last, val in triples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(loop_last(range(3)), [(False,0), (False,1), (True,2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def first_match(lst, f, default=None):\n",
    "    \"First element of `lst` matching predicate `f`, or `default` if none\"\n",
    "    return next((i for i,o in enumerate(lst) if f(o)), default)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [0,2,4,5,6,7,10]\n",
    "test_eq(first_match(a, lambda o:o%2), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def last_match(lst, f, default=None):\n",
    "    \"Last element of `lst` matching predicate `f`, or `default` if none\"\n",
    "    return next((i for i in range(len(lst)-1, -1, -1) if f(lst[i])), default)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(last_match(a, lambda o:o%2), 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fastuple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tuple with extended functionality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "num_methods = \"\"\"\n",
    "    __add__ __sub__ __mul__ __matmul__ __truediv__ __floordiv__ __mod__ __divmod__ __pow__\n",
    "    __lshift__ __rshift__ __and__ __xor__ __or__ __neg__ __pos__ __abs__\n",
    "\"\"\".split()\n",
    "rnum_methods = \"\"\"\n",
    "    __radd__ __rsub__ __rmul__ __rmatmul__ __rtruediv__ __rfloordiv__ __rmod__ __rdivmod__\n",
    "    __rpow__ __rlshift__ __rrshift__ __rand__ __rxor__ __ror__\n",
    "\"\"\".split()\n",
    "inum_methods = \"\"\"\n",
    "    __iadd__ __isub__ __imul__ __imatmul__ __itruediv__\n",
    "    __ifloordiv__ __imod__ __ipow__ __ilshift__ __irshift__ __iand__ __ixor__ __ior__\n",
    "\"\"\".split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class fastuple(tuple):\n",
    "    \"A `tuple` with elementwise ops and more friendly __init__ behavior\"\n",
    "    def __new__(cls, x=None, *rest):\n",
    "        if x is None: x = ()\n",
    "        if not isinstance(x,tuple):\n",
    "            if len(rest): x = (x,)\n",
    "            else:\n",
    "                try: x = tuple(iter(x))\n",
    "                except TypeError: x = (x,)\n",
    "        return super().__new__(cls, x+rest if rest else x)\n",
    "    \n",
    "    def _op(self,op,*args):\n",
    "        if not isinstance(self,fastuple): self = fastuple(self)\n",
    "        return type(self)(map(op,self,*map(cycle, args)))\n",
    "\n",
    "    def mul(self,*args):\n",
    "        \"`*` is already defined in `tuple` for replicating, so use `mul` instead\"\n",
    "        return fastuple._op(self, operator.mul,*args)\n",
    "\n",
    "    def add(self,*args):\n",
    "        \"`+` is already defined in `tuple` for concat, so use `add` instead\"\n",
    "        return fastuple._op(self, operator.add,*args)\n",
    "\n",
    "def _get_op(op):\n",
    "    if isinstance(op,str): op = getattr(operator,op)\n",
    "    def _f(self,*args): return self._op(op,*args)\n",
    "    return _f\n",
    "\n",
    "# Elementwise versions of the numeric dunders that `operator` provides, skipping any `tuple` already defines\n",
    "# (so `+`/`*` keep their concat/replicate meaning, hence the `add`/`mul` methods)\n",
    "for n in num_methods:\n",
    "    if not hasattr(fastuple, n) and hasattr(operator,n): setattr(fastuple,n,_get_op(n))\n",
    "\n",
    "# Named elementwise comparisons: `eq`, `ne`, `lt`, `le`, `gt`, `ge`\n",
    "for n in 'eq ne lt le gt ge'.split(): setattr(fastuple,n,_get_op(n))\n",
    "# `~` is elementwise logical not (`operator.__not__`), not bitwise invert\n",
    "setattr(fastuple,'__invert__',_get_op('__not__'))\n",
    "# elementwise `max`/`min` against the cycled args\n",
    "setattr(fastuple,'max',_get_op(max))\n",
    "setattr(fastuple,'min',_get_op(min))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L806){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### fastuple\n",
       "\n",
       ">      fastuple (x=None, *rest)\n",
       "\n",
       "*A `tuple` with elementwise ops and more friendly __init__ behavior*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L806){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### fastuple\n",
       "\n",
       ">      fastuple (x=None, *rest)\n",
       "\n",
       "*A `tuple` with elementwise ops and more friendly __init__ behavior*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(fastuple, title_level=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Friendly init behavior"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Common failure modes when trying to initialize a tuple in python:\n",
    "\n",
    "```py\n",
    "tuple(3)\n",
    "> TypeError: 'int' object is not iterable\n",
    "```\n",
    "\n",
    "or \n",
    "\n",
    "```py\n",
    "tuple(3, 4)\n",
    "> TypeError: tuple expected at most 1 arguments, got 2\n",
    "```\n",
    "\n",
    "However, `fastuple` allows you to define tuples like this and in the usual way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(fastuple(3), (3,))\n",
    "test_eq(fastuple(3,4), (3, 4))\n",
    "test_eq(fastuple((3,4)), (3, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Elementwise operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L825){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "##### fastuple.add\n",
       "\n",
       ">      fastuple.add (*args)\n",
       "\n",
       "*`+` is already defined in `tuple` for concat, so use `add` instead*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L825){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "##### fastuple.add\n",
       "\n",
       ">      fastuple.add (*args)\n",
       "\n",
       "*`+` is already defined in `tuple` for concat, so use `add` instead*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(fastuple.add, title_level=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(fastuple.add((1,1),(2,2)), (3,3))\n",
    "test_eq_type(fastuple(1,1).add(2), fastuple(3,3))\n",
    "test_eq(fastuple('1','2').add('2'), fastuple('12','22'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L821){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "##### fastuple.mul\n",
       "\n",
       ">      fastuple.mul (*args)\n",
       "\n",
       "*`*` is already defined in `tuple` for replicating, so use `mul` instead*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L821){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "##### fastuple.mul\n",
       "\n",
       ">      fastuple.mul (*args)\n",
       "\n",
       "*`*` is already defined in `tuple` for replicating, so use `mul` instead*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(fastuple.mul, title_level=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq_type(fastuple(1,1).mul(2), fastuple(2,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Other Elementwise Operations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, the elementwise comparisons `eq`, `ne`, `lt`, `le`, `gt` and `ge`, plus `min` and `max`, are available, for example:\n",
    "- `le`: less than or equal\n",
    "- `eq`: equal\n",
    "- `gt`: greater than\n",
    "- `min`: elementwise minimum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(fastuple(3,1).le(1), (False, True))\n",
    "test_eq(fastuple(3,1).eq(1), (False, True))\n",
    "test_eq(fastuple(3,1).gt(1), (True, False))\n",
    "test_eq(fastuple(3,1).min(2), (2,1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also do other elementwise operations like negate a `fastuple`, or subtract two `fastuple`s:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(-fastuple(1,2), (-1,-2))\n",
    "test_eq(~fastuple(1,0,1), (False,True,False))\n",
    "\n",
    "test_eq(fastuple(1,1)-fastuple(2,2), (-1,-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(type(fastuple(1)), fastuple)\n",
    "test_eq_type(fastuple(1,2), fastuple(1,2))\n",
    "test_ne(fastuple(1,2), fastuple(1,3))\n",
    "test_eq(fastuple(), ())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions on Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Utilities for functional programming or for defining, modifying, or debugging functions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _Arg:\n",
    "    \"Placeholder for the `i`-th positional argument passed when calling a `bind` object\"\n",
    "    def __init__(self,i): self.i = i\n",
    "# Predefined placeholders for the first five call arguments (see `bind`)\n",
    "arg0 = _Arg(0)\n",
    "arg1 = _Arg(1)\n",
    "arg2 = _Arg(2)\n",
    "arg3 = _Arg(3)\n",
    "arg4 = _Arg(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class bind:\n",
    "    \"Same as `partial`, except you can use `arg0` `arg1` etc param placeholders\"\n",
    "    def __init__(self, func, *pargs, **pkwargs):\n",
    "        self.func,self.pargs,self.pkwargs = func,pargs,pkwargs\n",
    "        # call args after the highest positional placeholder are passed through after `pargs`\n",
    "        self.maxi = max((x.i for x in pargs if isinstance(x, _Arg)), default=-1)\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        kwargs = {**self.pkwargs,**kwargs}\n",
    "        # Every placeholder indexes the original call args, so several keyword placeholders can be\n",
    "        # combined (and mixed with positional ones) without their indices shifting\n",
    "        used = {v.i for v in kwargs.values() if isinstance(v,_Arg)}\n",
    "        kwargs = {k:(args[v.i] if isinstance(v,_Arg) else v) for k,v in kwargs.items()}\n",
    "        fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs]\n",
    "        # pass through the remaining call args, except those already consumed by a keyword placeholder\n",
    "        fargs += [o for i,o in enumerate(args) if i>self.maxi and i not in used]\n",
    "        return self.func(*fargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L852){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### bind\n",
       "\n",
       ">      bind (func, *pargs, **pkwargs)\n",
       "\n",
       "*Same as `partial`, except you can use `arg0` `arg1` etc param placeholders*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L852){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### bind\n",
       "\n",
       ">      bind (func, *pargs, **pkwargs)\n",
       "\n",
       "*Same as `partial`, except you can use `arg0` `arg1` etc param placeholders*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(bind, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`bind` is the same as `partial`, but also allows you to reorder positional arguments using the placeholders `arg{i}`, where `i` is the zero-based index of the positional argument passed when the bound function is called. Only `arg0` to `arg4` are predefined, so `bind` currently supports reordering of up to the first 5 positional arguments.\n",
    "\n",
    "Consider the function `myfn` below, which has 3 required positional arguments. These arguments can be referenced as `arg0`, `arg1`, and `arg2`, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def myfn(a,b,c,d=1,e=2): return(a,b,c,d,e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the example below, we bind the arguments of `myfn` as follows:\n",
    "\n",
    "- The second input `14`, referenced by `arg1`, is substituted for the first positional argument.\n",
    "- We supply a fixed value of `17` for the second positional argument.\n",
    "- The first input `19`, referenced by `arg0`, is substituted for the third positional argument.\n",
    "- We set the named argument `e` to `3`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(bind(myfn, arg1, 17, arg0, e=3)(19,14), (14,17,19,1,3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this next example:\n",
    "\n",
    "- We supply a fixed value of `17` for the first positional argument.\n",
    "- The first input `19`, referenced by `arg0`, becomes the second positional argument.\n",
    "- The second input `14` is passed through as the third positional argument.\n",
    "- We override the default value of the named argument `e` with `3`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(bind(myfn, 17, arg0, e=3)(19,14), (17,19,14,1,3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example of using `bind` like `partial`, without reordering any arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(bind(myfn)(17,19,14), (17,19,14,1,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`bind` can also be used to change default values. In the example below, we use the first input `3` to override the default value of the named argument `e`, and supply fixed values for the first three positional arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(bind(myfn, 17,19,14,e=arg0)(3), (17,19,14,1,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def mapt(func, *iterables):\n",
    "    \"Tuplified `map`\"\n",
    "    return tuple(map(func, *iterables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = [0,1,2,3]\n",
    "test_eq(mapt(operator.neg, t), (0,-1,-2,-3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def map_ex(iterable, f, *args, gen=False, **kwargs):\n",
    "    \"Like `map`, but use `bind`, and supports `str` and indexing\"\n",
    "    g = (bind(f,*args,**kwargs) if callable(f)\n",
    "         else f.format if isinstance(f,str)\n",
    "         else f.__getitem__)\n",
    "    res = map(g, iterable)\n",
    "    if gen: return res\n",
    "    return list(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(map_ex(t,operator.neg), [0,-1,-2,-3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `f` is a string then it is treated as a format string to create the mapping:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(map_ex(t, '#{}#'), ['#0#','#1#','#2#','#3#'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `f` is a dictionary (or anything supporting `__getitem__`) then it is indexed to create the mapping:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(map_ex(t, list('abcd')), list('abcd'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also pass the same `arg` params that `bind` accepts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(a=None,b=None): return b\n",
    "test_eq(map_ex(t, f, b=arg0), range(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def compose(*funcs, order=None):\n",
    "    \"Chain `funcs` left to right (sorted by `order` if given) into one function; extra `*args`/`**kwargs` go to every step\"\n",
    "    funcs = listify(funcs)\n",
    "    if not funcs: return noop\n",
    "    if len(funcs)==1: return funcs[0]\n",
    "    if order is not None: funcs = sorted_ex(funcs, key=order)\n",
    "    def _inner(x, *args, **kwargs):\n",
    "        res = x\n",
    "        for fn in funcs: res = fn(res, *args, **kwargs)\n",
    "        return res\n",
    "    return _inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f1 = lambda o,p=0: (o*2)+p\n",
    "f2 = lambda o,p=1: (o+1)/p\n",
    "test_eq(f2(f1(3)), compose(f1,f2)(3))\n",
    "test_eq(f2(f1(3,p=3),p=3), compose(f1,f2)(3,p=3))\n",
    "test_eq(f2(f1(3,  3),  3), compose(f1,f2)(3,  3))\n",
    "\n",
    "f1.order = 1\n",
    "test_eq(f1(f2(3)), compose(f1,f2, order=\"order\")(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def maps(*args, retain=noop):\n",
    "    \"Like `map`, but every arg except the last is composed into one function first; each output is `retain(result, item)`\"\n",
    "    composed = compose(*args[:-1])\n",
    "    items = args[-1]\n",
    "    return map(lambda b: retain(composed(b), b), items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(maps([1]), [1])\n",
    "test_eq(maps(operator.neg, [1,2]), [-1,-2])\n",
    "test_eq(maps(operator.neg, operator.neg, [1,2]), [1,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def partialler(f, *args, order=None, **kwargs):\n",
    "    \"Like `functools.partial` but also copies over docstring\"\n",
    "    res = partial(f,*args,**kwargs)\n",
    "    res.__doc__ = f.__doc__\n",
    "    # an explicit `order` wins; otherwise inherit `f.order` when `f` defines one\n",
    "    if order is None:\n",
    "        if hasattr(f,'order'): res.order = f.order\n",
    "    else: res.order = order\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _f(x,a=1):\n",
    "    \"test func\"\n",
    "    return x-a\n",
    "_f.order=1\n",
    "\n",
    "# `order` is inherited from `_f` when not given explicitly\n",
    "f = partialler(_f, 2)\n",
    "test_eq(f.order, 1)\n",
    "# the stored 2 becomes `x` and the call-time 3 becomes `a`: 2-3 == -1\n",
    "test_eq(f(3), -1)\n",
    "# an explicit `order` overrides `_f.order`, and the docstring is copied over\n",
    "f = partialler(_f, a=2, order=3)\n",
    "test_eq(f.__doc__, \"test func\")\n",
    "test_eq(f.order, 3)\n",
    "test_eq(f(3), _f(3,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class partial0:\n",
    "    \"Like `partialler`, but args passed to callable are inserted at start, instead of at end\"\n",
    "    def __init__(self, f, *args, order=None, **kwargs):\n",
    "        self.f,self.args,self.kwargs = f,args,kwargs\n",
    "        # an explicit `order` wins, otherwise fall back to `f.order` (or None)\n",
    "        self.order = ifnone(order, getattr(f,'order',None))\n",
    "        self.__doc__ = f.__doc__\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        # call-time args come first, followed by the args stored at construction\n",
    "        return self.f(*args, *self.args, **kwargs, **self.kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = partial0(_f, 2)\n",
    "test_eq(f.order, 1)\n",
    "# the call-time 3 becomes `x` and the stored 2 becomes `a`: 3-2 == 1\n",
    "test_eq(f(3), 1) # NB: different to `partialler` example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def instantiate(t):\n",
    "    \"Instantiate `t` if it's a type, otherwise do nothing\"\n",
    "    # non-types (including already-built instances) are passed through untouched\n",
    "    if not isinstance(t, type): return t\n",
    "    return t()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a type is called with no args; anything else is returned as-is\n",
    "test_eq_type(instantiate(int), 0)\n",
    "test_eq_type(instantiate(1), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _using_attr(f, attr, x):\n",
    "    \"Apply `f` to attribute `attr` of `x`\"\n",
    "    val = getattr(x, attr)\n",
    "    return f(val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def using_attr(f, attr):\n",
    "    \"Construct a function which applies `f` to the argument's attribute `attr`\"\n",
    "    # NOTE(review): a `partial` over a module-level helper (rather than a closure) is presumably used so the result stays picklable\n",
    "    applier = partial(_using_attr, f, attr)\n",
    "    return applier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = Path('/a/b.txt')\n",
    "# equivalent to `lambda o: str.upper(o.name)`\n",
    "f = using_attr(str.upper, 'name')\n",
    "test_eq(f(t), 'B.TXT')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Self (with an _uppercase_ S)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Concise Way To Create Lambdas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _Self:\n",
    "    \"An alternative to `lambda` for calling methods on passed object.\"\n",
    "    # `nms` holds the attribute names of the chain; `args`/`kwargs` hold, in parallel, the call args recorded\n",
    "    # for each name (None when that name was never called). `ready` is False right after an attribute access\n",
    "    # that no call has closed yet; in that state the next call records args instead of evaluating, so a chain\n",
    "    # must end with a call (e.g. `Self.imag()`) before it can be applied to an object.\n",
    "    def __init__(self): self.nms,self.args,self.kwargs,self.ready = [],[],[],True\n",
    "    def __repr__(self): return f'self: {self.nms}({self.args}, {self.kwargs})'\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        if self.ready:\n",
    "            # evaluate: replay the recorded chain against the first positional argument\n",
    "            x = args[0]\n",
    "            for n,a,k in zip(self.nms,self.args,self.kwargs):\n",
    "                x = getattr(x,n)\n",
    "                # recorded calls are replayed only on callables; other attribute values are kept as-is\n",
    "                if callable(x) and a is not None: x = x(*a, **k)\n",
    "            return x\n",
    "        else:\n",
    "            # record: these are the args for the most recently accessed attribute\n",
    "            self.args.append(args)\n",
    "            self.kwargs.append(kwargs)\n",
    "            self.ready = True\n",
    "            return self\n",
    "\n",
    "    def __getattr__(self,k):\n",
    "        # a previous attribute that was never called gets None placeholders, keeping the lists aligned\n",
    "        if not self.ready:\n",
    "            self.args.append(None)\n",
    "            self.kwargs.append(None)\n",
    "        self.nms.append(k)\n",
    "        self.ready = False\n",
    "        return self\n",
    "\n",
    "    def _call(self, *args, **kwargs):\n",
    "        # used by `Self(...)`: the resulting function calls whatever object it is applied to with these args\n",
    "        self.args,self.kwargs,self.nms = [args],[kwargs],['__call__']\n",
    "        self.ready = True\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _SelfCls:\n",
    "    \"Factory behind `Self`: each attribute access, index or call starts a new `_Self` chain\"\n",
    "    def __getattr__(self,k): return getattr(_Self(),k)\n",
    "    def __getitem__(self,i):\n",
    "        # `Self[i]` records a `__getitem__(i)` call on a fresh chain\n",
    "        return getattr(_Self(),'__getitem__')(i)\n",
    "    def __call__(self,*args,**kwargs):\n",
    "        # `Self(...)` builds a chain that calls the object it is later applied to\n",
    "        return _Self()._call(*args,**kwargs)\n",
    "\n",
    "Self = _SelfCls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# nbdev convention: names listed in `_all_` are added to the generated module's `__all__`\n",
    "_all_ = ['Self']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a concise way to create lambdas that are calling methods on an object (note the capitalization!)\n",
    "\n",
    "`Self.sum()`, for instance, is a shortcut for `lambda o: o.sum()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `Self.sum()` builds the same function as `lambda o: o.sum()`\n",
    "f = Self.sum()\n",
    "x = np.array([3.,1])\n",
    "test_eq(f(x), 4.)\n",
    "\n",
    "# This is equivalent to above\n",
    "f = lambda o: o.sum()\n",
    "x = np.array([3.,1])\n",
    "test_eq(f(x), 4.)\n",
    "\n",
    "f = Self.argmin()\n",
    "arr = np.array([1,2,3,4,5])\n",
    "test_eq(f(arr), arr.argmin())\n",
    "\n",
    "# calls can be chained...\n",
    "f = Self.sum().is_integer()\n",
    "x = np.array([3.,1])\n",
    "test_eq(f(x), True)\n",
    "\n",
    "# ...and mixed with plain attribute access (`real`)\n",
    "f = Self.sum().real.is_integer()\n",
    "x = np.array([3.,1])\n",
    "test_eq(f(x), True)\n",
    "\n",
    "# a non-callable attribute must still end in `()`; its value is returned unchanged\n",
    "f = Self.imag()\n",
    "test_eq(f(3), 0)\n",
    "\n",
    "# indexing: `Self[1]` acts like `lambda o: o[1]`\n",
    "f = Self[1]\n",
    "test_eq(f(x), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Self` is also callable, which creates a function which calls any function passed to it, using the arguments passed to `Self`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5, 2]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(a, b=3): return a+b+2\n",
    "def g(a, b=3): return a*b\n",
    "# `fg(h)` calls `h(1, b=2)`, so mapping it over `[f,g]` gives `[f(1,b=2), g(1,b=2)]`\n",
    "fg = Self(1,b=2)\n",
    "list(map(fg, [f,g]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Patching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def copy_func(f):\n",
    "    \"Copy a non-builtin function (NB `copy.copy` does not work for this)\"\n",
    "    # anything that isn't a plain Python function (builtins, partials, other callables) falls back to `copy`\n",
    "    if not isinstance(f,FunctionType): return copy(f)\n",
    "    fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)\n",
    "    fn.__kwdefaults__ = f.__kwdefaults__\n",
    "    fn.__dict__.update(f.__dict__)\n",
    "    fn.__annotations__.update(f.__annotations__)\n",
    "    fn.__qualname__ = f.__qualname__\n",
    "    # the constructor takes the docstring from the code object and `__module__` from `__globals__`,\n",
    "    # so values assigned to these after definition would otherwise be lost\n",
    "    fn.__doc__ = f.__doc__\n",
    "    fn.__module__ = f.__module__\n",
    "    return fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes it may be desirable to make a copy of a function that doesn't point to the original object. When you use Python's built-in `copy.copy` or `copy.deepcopy` to copy a function, you get a reference to the original object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy as cp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def foo(): pass\n",
    "# for functions, `copy.copy` and `copy.deepcopy` just hand back the original object\n",
    "a = cp.copy(foo)\n",
    "b = cp.deepcopy(foo)\n",
    "\n",
    "a.someattr = 'hello' # since a and b point at the same object, updating a will update b\n",
    "test_eq(b.someattr, 'hello')\n",
    "\n",
    "assert a is foo and b is foo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, with `copy_func`, you can retrieve a copy of a function without a reference to the original object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = copy_func(foo) # c is an independent object\n",
    "assert c is not foo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keyword-only defaults (`__kwdefaults__`) survive the copy\n",
    "def g(x, *, y=3): return x+y\n",
    "test_eq(copy_func(g)(4), 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _clsmethod:\n",
    "    \"Descriptor that binds `f` to the class it is looked up on, like `classmethod`\"\n",
    "    def __init__(self, f): self.f = f\n",
    "    def __get__(self, _, f_cls):\n",
    "        # bind to the owner class (the instance, if any, is ignored), so subclasses receive themselves\n",
    "        return MethodType(self.f, f_cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def patch_to(cls, as_prop=False, cls_method=False):\n",
    "    \"Decorator: add `f` to `cls`\"\n",
    "    if not isinstance(cls, (tuple,list)): cls=(cls,)\n",
    "    def _inner(f):\n",
    "        # computed once up front, so the final lookup below also works when `cls` is empty\n",
    "        nm = f.__name__\n",
    "        for c_ in cls:\n",
    "            nf = copy_func(f)\n",
    "            # `functools.update_wrapper` is avoided (it caused problems when the patched function is passed\n",
    "            # to `Pipeline`), so the standard wrapper attributes are copied over manually\n",
    "            for o in functools.WRAPPER_ASSIGNMENTS: setattr(nf, o, getattr(f,o))\n",
    "            nf.__qualname__ = f\"{c_.__name__}.{nm}\"\n",
    "            if cls_method: setattr(c_, nm, _clsmethod(nf))\n",
    "            elif as_prop: setattr(c_, nm, property(nf))\n",
    "            else:\n",
    "                # keep the first-seen original under `_orig_<name>` so it can still be reached\n",
    "                onm = '_orig_'+nm\n",
    "                if hasattr(c_, nm) and not hasattr(c_, onm): setattr(c_, onm, getattr(c_, nm))\n",
    "                setattr(c_, nm, nf)\n",
    "        # Avoid clobbering existing functions: the decorated name is rebound to whatever `nm` already refers to\n",
    "        # in these globals (or builtins), else None.\n",
    "        # NOTE(review): `globals()` here is the module defining `patch_to`, not the caller's module\n",
    "        return globals().get(nm, builtins.__dict__.get(nm, None))\n",
    "    return _inner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `@patch_to` decorator allows you to [monkey patch](https://stackoverflow.com/questions/5626193/what-is-monkey-patching) a function into a class as a method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T3(int): pass  \n",
    "\n",
    "@patch_to(_T3)\n",
    "def func1(self, a): return self+a\n",
    "\n",
    "t = _T3(1) # `t` is an `_T3` (an `int` subclass) with value 1\n",
    "test_eq(t.func1(2), 3) # we add 2 to `t`, so 2 + 1 = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can access instance properties in the usual way via `self`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T4():\n",
    "    def __init__(self, g): self.g = g\n",
    "        \n",
    "# the patched function receives the instance as `self`, like any other method\n",
    "@patch_to(_T4)\n",
    "def greet(self, x): return self.g + x\n",
    "        \n",
    "t = _T4('hello ') # this sets self.g = 'hello '\n",
    "test_eq(t.greet('world'), 'hello world') #t.greet('world') will append 'world' to 'hello '"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can instead specify that the method should be a class method by setting `cls_method=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T5(int): attr = 3 # attr is a class attribute we will access in a later method\n",
    "    \n",
    "# with `cls_method=True` the first argument is bound to the class itself, so no instance is needed\n",
    "@patch_to(_T5, cls_method=True)\n",
    "def func(cls, x): return cls.attr + x # you can access class attributes in the normal way\n",
    "\n",
    "test_eq(_T5.func(4), 7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Additionally, you can specify that the function you want to patch should be exposed as a property with `as_prop=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@patch_to(_T5, as_prop=True)\n",
    "def add_ten(self): return self + 10\n",
    "\n",
    "t = _T5(4)\n",
    "# a property is read without parentheses\n",
    "test_eq(t.add_ten, 14)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of passing one class to the `@patch_to` decorator, you can pass multiple classes in a tuple to simultaneously patch more than one class with the same method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T6(int): pass\n",
    "class _T7(int): pass\n",
    "\n",
    "# the same function is installed on both classes\n",
    "@patch_to((_T6,_T7))\n",
    "def func_mult(self, a): return self*a\n",
    "\n",
    "t = _T6(2)\n",
    "test_eq(t.func_mult(4), 8)\n",
    "t = _T7(2)\n",
    "test_eq(t.func_mult(4), 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def patch(f=None, *, as_prop=False, cls_method=False):\n",
    "    \"Decorator: add `f` to the first parameter's class (based on f's type annotations)\"\n",
    "    # used with options only (e.g. `@patch(as_prop=True)`): return a decorator still waiting for `f`\n",
    "    if f is None: return partial(patch, as_prop=as_prop, cls_method=cls_method)\n",
    "    ann,glb,loc = get_annotations_ex(f)\n",
    "    if cls_method and 'cls' in ann: typ = ann.pop('cls')\n",
    "    else:\n",
    "        # use the first annotated *parameter*; the return annotation must never be taken as the class to patch\n",
    "        params = [v for k,v in ann.items() if k!='return']\n",
    "        if not params: raise TypeError(f\"`patch` needs a type annotation on the first parameter of {getattr(f,'__name__',f)!r}\")\n",
    "        typ = params[0]\n",
    "    # a union annotation such as `A|B` becomes a tuple, so every class in it gets patched\n",
    "    cls = union2tuple(eval_type(typ, glb, loc))\n",
    "    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`@patch` is an alternative to `@patch_to` that allows you to similarly monkey patch class(es) by using [type annotations](https://docs.python.org/3/library/typing.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T8(int): pass  \n",
    "\n",
    "# the class to patch is taken from the annotation of the first parameter\n",
    "@patch\n",
    "def func(self:_T8, a): return self+a\n",
    "\n",
    "t = _T8(1)  # `t` is an `_T8` (an `int` subclass) with value 1\n",
    "test_eq(t.func(3), 4) # we add 3 to `t`, so 3 + 1 = 4\n",
    "test_eq(t.func.__qualname__, '_T8.func')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly to `patch_to`, you can supply a union of classes instead of a single class in your type annotations to patch multiple classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T9(int): pass \n",
    "\n",
    "@patch\n",
    "def func2(x:_T8|_T9, a): return x*a # will patch both _T8 and _T9\n",
    "\n",
    "# each class gets its own copy, with a matching `__qualname__`\n",
    "t = _T8(2)\n",
    "test_eq(t.func2(4), 8)\n",
    "test_eq(t.func2.__qualname__, '_T8.func2')\n",
    "\n",
    "t = _T9(2)\n",
    "test_eq(t.func2(4), 8)\n",
    "test_eq(t.func2.__qualname__, '_T9.func2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just like with the `patch_to` decorator, you can use the `as_prop` and `cls_method` parameters with the `patch` decorator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@patch(as_prop=True)\n",
    "def add_ten(self:_T5): return self + 10\n",
    "\n",
    "t = _T5(4)\n",
    "# installed as a property, so it is read without parentheses\n",
    "test_eq(t.add_ten, 14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a fresh `_T5`, so this example doesn't depend on the earlier `patch_to` one\n",
    "class _T5(int): attr = 3 # attr is a class attribute we will access in a later method\n",
    "    \n",
    "@patch(cls_method=True)\n",
    "def func(cls:_T5, x): return cls.attr + x # you can access class attributes in the normal way\n",
    "\n",
    "test_eq(_T5.func(4), 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def patch_property(f):\n",
    "    \"Deprecated; use `patch(as_prop=True)` instead\"\n",
    "    # `stacklevel=2` points the warning at the caller's decorator line rather than at this function\n",
    "    warnings.warn(\"`patch_property` is deprecated and will be removed; use `patch(as_prop=True)` instead\", stacklevel=2)\n",
    "    # delegate to `patch` so string (postponed) and union annotations are resolved the same way as there\n",
    "    return patch(f, as_prop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Patching a `classmethod` shouldn't affect how Python's inheritance works:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FastParent: pass\n",
    "\n",
    "@patch(cls_method=True)\n",
    "def type_cls(cls: FastParent): return cls\n",
    "\n",
    "class FastChild(FastParent): pass\n",
    "\n",
    "parent = FastParent()\n",
    "test_eq(parent.type_cls(), FastParent)\n",
    "\n",
    "# the subclass receives itself as `cls`, not `FastParent`\n",
    "child = FastChild()\n",
    "test_eq(child.type_cls(), FastChild)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Other Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def compile_re(pat):\n",
    "    \"Compile `pat` if it's not None\"\n",
    "    # a missing pattern stays `None` instead of being handed to `re.compile`\n",
    "    if pat is None: return None\n",
    "    return re.compile(pat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `None` passes through; a string is compiled into a pattern object\n",
    "assert compile_re(None) is None\n",
    "assert compile_re('a').match('ab')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ImportEnum(enum.Enum):\n",
    "    \"An `Enum` that can have its values imported\"\n",
    "    @classmethod\n",
    "    def imports(cls):\n",
    "        \"Bind each member to its name in the caller's local namespace\"\n",
    "        # NOTE(review): writes to `f_locals` only become visible names where locals are a real dict (module or\n",
    "        # class body, e.g. notebook top level); inside a function they will generally not act as local variables\n",
    "        g = sys._getframe(1).f_locals\n",
    "        for o in cls: g[o.name]=o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1024){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### ImportEnum\n",
       "\n",
       ">      ImportEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">                  start=1, boundary=None)\n",
       "\n",
       "*An `Enum` that can have its values imported*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1024){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### ImportEnum\n",
       "\n",
       ">      ImportEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">                  start=1, boundary=None)\n",
       "\n",
       "*An `Enum` that can have its values imported*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# render the API docs for `ImportEnum` with a level-4 heading\n",
    "show_doc(ImportEnum, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_T = ImportEnum('_T', {'foobar':1, 'goobar':2})\n",
    "# injects `foobar` and `goobar` into this (top-level) namespace\n",
    "_T.imports()\n",
    "test_eq(foobar, _T.foobar)\n",
    "test_eq(goobar, _T.goobar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class StrEnum(str,ImportEnum):\n",
    "    \"An `ImportEnum` that behaves like a `str`\"\n",
    "    def __str__(self):\n",
    "        # stringify as the bare member name\n",
    "        return self.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1032){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### StrEnum\n",
       "\n",
       ">      StrEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">               start=1, boundary=None)\n",
       "\n",
       "*An `ImportEnum` that behaves like a `str`*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1032){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### StrEnum\n",
       "\n",
       ">      StrEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">               start=1, boundary=None)\n",
       "\n",
       "*An `ImportEnum` that behaves like a `str`*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# render the API docs for `StrEnum` with a level-4 heading\n",
    "show_doc(StrEnum, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def str_enum(name, *vals):\n",
    "    \"Simplified creation of `StrEnum` types\"\n",
    "    # every member's name doubles as its value\n",
    "    members = dict(zip(vals, vals))\n",
    "    return StrEnum(name, members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ValEnum(str,ImportEnum):\n",
    "    \"An `ImportEnum` that stringifies using values\"\n",
    "    def __str__(self):\n",
    "        # unlike `StrEnum`, use the member's value rather than its name\n",
    "        return self.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "#### ValEnum\n",
       "\n",
       ">      ValEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">               start=1, boundary=None)\n",
       "\n",
       "*An `ImportEnum` that stringifies using values*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "#### ValEnum\n",
       "\n",
       ">      ValEnum (value, names=None, module=None, qualname=None, type=None,\n",
       ">               start=1, boundary=None)\n",
       "\n",
       "*An `ImportEnum` that stringifies using values*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# render the API docs for `ValEnum` with a level-4 heading\n",
    "show_doc(ValEnum, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a A\n"
     ]
    }
   ],
   "source": [
    "_T = str_enum('_T', 'a', 'b')\n",
    "# members format as their name, compare equal to plain strings and support `str` methods\n",
    "test_eq(f'{_T.a}', 'a')\n",
    "test_eq(_T.a, 'a')\n",
    "test_eq(list(_T.__members__), ['a','b'])\n",
    "print(_T.a, _T.a.upper())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Stateful:\n",
    "    \"A base class/mixin for objects that should not serialize all their state\"\n",
    "    # names of instance attributes to leave out when pickling; `_state` is always left out too\n",
    "    _stateattrs=()\n",
    "    def __init__(self,*args,**kwargs):\n",
    "        self._init_state()\n",
    "        super().__init__(*args,**kwargs) # required for mixin usage\n",
    "\n",
    "    def __getstate__(self):\n",
    "        attrs = self._stateattrs\n",
    "        # accept a single name given as a plain string, as well as any list/tuple of names\n",
    "        if isinstance(attrs, str): attrs = (attrs,)\n",
    "        skip = {*attrs, '_state'}\n",
    "        return {k:v for k,v in self.__dict__.items() if k not in skip}\n",
    "\n",
    "    def __setstate__(self, state):\n",
    "        self.__dict__.update(state)\n",
    "        # rebuild the transient state that `__getstate__` dropped\n",
    "        self._init_state()\n",
    "\n",
    "    def _init_state(self):\n",
    "        \"Override for custom init and deserialization logic\"\n",
    "        self._state = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1042){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### Stateful\n",
       "\n",
       ">      Stateful (*args, **kwargs)\n",
       "\n",
       "*A base class/mixin for objects that should not serialize all their state*"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1042){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### Stateful\n",
       "\n",
       ">      Stateful (*args, **kwargs)\n",
       "\n",
       "*A base class/mixin for objects that should not serialize all their state*"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# render the API docs for `Stateful` with a level-4 heading\n",
    "show_doc(Stateful, title_level=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T(Stateful):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.a=1\n",
    "        self._state['test']=2\n",
    "\n",
    "t = _T()\n",
    "# round-trip through pickle: regular attributes survive, while `_state` is rebuilt empty by `_init_state`\n",
    "t2 = pickle.loads(pickle.dumps(t))\n",
    "test_eq(t.a,1)\n",
    "test_eq(t._state['test'],2)\n",
    "test_eq(t2.a,1)\n",
    "test_eq(t2._state,{})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Override `_init_state` to do any necessary setup steps that are required during `__init__` or during deserialization (e.g. `pickle.load`). Here's an example of how `Stateful` simplifies the official Python example for [Handling Stateful Objects](https://docs.python.org/3/library/pickle.html#handling-stateful-objects)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextReader(Stateful):\n",
    "    \"\"\"Print and number lines in a text file.\"\"\"\n",
    "    # the open file object can't be pickled, so it is excluded here and reopened by `_init_state`\n",
    "    _stateattrs=('file',)\n",
    "    def __init__(self, filename):\n",
    "        # set before `super().__init__()`, because that calls `_init_state`, which reads both\n",
    "        self.filename,self.lineno = filename,0\n",
    "        super().__init__()\n",
    "\n",
    "    def readline(self):\n",
    "        # returns None once the file is exhausted\n",
    "        self.lineno += 1\n",
    "        line = self.file.readline()\n",
    "        if line: return f\"{self.lineno}: {line.strip()}\"\n",
    "\n",
    "    def _init_state(self):\n",
    "        # (re)open the file and skip the lines already read, so an unpickled reader resumes where it left off\n",
    "        # NOTE(review): the handle is never closed by this class\n",
    "        self.file = open(self.filename)\n",
    "        for _ in range(self.lineno): self.file.readline()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1: {\n",
      "2: \"cells\": [\n",
      "3: {\n"
     ]
    }
   ],
   "source": [
    "reader = TextReader(\"00_test.ipynb\")\n",
    "print(reader.readline())\n",
    "print(reader.readline())\n",
    "\n",
    "# the unpickled copy reopens the file and resumes after the two lines already read\n",
    "new_reader = pickle.loads(pickle.dumps(reader))\n",
    "print(new_reader.readline())\n",
    "for r in (reader, new_reader): r.file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class NotStr(GetAttr):\n",
    "    \"Behaves like a `str`, but isn't an instance of one\"\n",
    "    # NOTE(review): presumably `GetAttr` forwards unknown attributes (e.g. `upper`) to the attribute named by `_default`\n",
    "    _default = 's'\n",
    "    def __init__(self, s): self.s = s.s if isinstance(s, NotStr) else s\n",
    "    def __repr__(self): return repr(self.s)\n",
    "    def __str__(self): return self.s\n",
    "    def __add__(self, b): return NotStr(self.s+str(b))\n",
    "    def __mul__(self, b): return NotStr(self.s*b)\n",
    "    def __len__(self): return len(self.s)\n",
    "    # comparisons and membership unwrap another `NotStr` and otherwise use `b` as-is\n",
    "    def __eq__(self, b): return self.s==(b.s if isinstance(b, NotStr) else b)\n",
    "    def __lt__(self, b): return self.s<(b.s if isinstance(b, NotStr) else b)\n",
    "    def __hash__(self): return hash(self.s)\n",
    "    def __bool__(self): return bool(self.s)\n",
    "    def __contains__(self, b): return (b.s if isinstance(b, NotStr) else b) in self.s\n",
    "    def __iter__(self): return iter(self.s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = NotStr(\"hello\")\n",
    "assert not isinstance(s, str)\n",
    "# ...yet it compares, repeats and measures like the string it wraps\n",
    "test_eq(s, 'hello')\n",
    "test_eq(s*2, 'hellohello')\n",
    "test_eq(len(s), 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class PrettyString(str):\n",
    "    \"Little hack to get strings to show properly in Jupyter.\"\n",
    "    def __repr__(self):\n",
    "        # using the text itself as the repr makes Jupyter render newlines/tabs instead of escape sequences\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1015){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### PrettyString\n",
       "\n",
       "\n",
       "\n",
       "Little hack to get strings to show properly in Jupyter."
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastcore/blob/master/fastcore/basics.py#L1015){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "#### PrettyString\n",
       "\n",
       "\n",
       "\n",
       "Little hack to get strings to show properly in Jupyter."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# render the API docs for `PrettyString` with a level-4 heading\n",
    "show_doc(PrettyString, title_level=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Allow strings with special characters to render properly in Jupyter. Without calling `print()`, strings with special characters are displayed like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'a string\\nwith\\nnew\\nlines and\\ttabs'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with_special_chars='a string\\nwith\\nnew\\nlines and\\ttabs'\n",
    "# the default repr shows escape sequences instead of actual newlines/tabs\n",
    "with_special_chars"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can correct this with `PrettyString`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "a string\n",
       "with\n",
       "new\n",
       "lines and\ttabs"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the repr is now the raw text, so newlines and tabs render as-is\n",
    "PrettyString(with_special_chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def even_mults(start, stop, n):\n",
    "    \"Build log-stepped array from `start` to `stop` in `n` steps.\"\n",
    "    # a single step collapses to `stop` itself (a scalar, not a list)\n",
    "    if n==1: return stop\n",
    "    ratio = (stop/start)**(1/(n-1))\n",
    "    return [start*ratio**k for k in range(n)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(even_mults(2,8,3), [2,4,8])\n",
    "test_eq(even_mults(2,32,5), [2,4,8,16,32])\n",
    "# n==1 returns `stop` itself rather than a list\n",
    "test_eq(even_mults(2,8,1), 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def num_cpus():\n",
    "    \"Get number of cpus\"\n",
    "    # `os.sched_getaffinity` (not available on every platform) counts only the CPUs this process may run on\n",
    "    get_affinity = getattr(os, 'sched_getaffinity', None)\n",
    "    if get_affinity is None: return os.cpu_count()\n",
    "    return len(get_affinity(0))\n",
    "\n",
    "# computed once, when this cell (or the exported module) runs, and stored on the shared `defaults` namespace\n",
    "defaults.cpus = num_cpus()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the result depends on the machine running the notebook\n",
    "num_cpus()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def add_props(f, g=None, n=2):\n",
    "    \"Create properties passing each of `range(n)` to f\"\n",
    "    if g is None: return (property(partial(f,i)) for i in range(n))\n",
    "    return (property(partial(f,i), partial(g,i)) for i in range(n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T:\n",
    "    a,b = add_props(lambda i,x: i*2)\n",
    "\n",
    "t = _T()\n",
    "test_eq([t.a, t.b], [0, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T:\n",
    "    def __init__(self, v): self.v = v\n",
    "    # Each setter receives its property index first, then the instance and the new value\n",
    "    def _set(i, self, v): self.v[i] = v\n",
    "    a,b = add_props(lambda i,x: x.v[i], _set)\n",
    "\n",
    "t = _T([0,2])\n",
    "test_eq((t.a, t.b), (0, 2))\n",
    "t.a += 1\n",
    "t.b = 3\n",
    "test_eq((t.a, t.b), (1, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _typeerr(arg, val, typ): return TypeError(f\"{arg}=={val} not {typ}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def typed(f):\n",
    "    \"Decorator to check param and return types at runtime, raising `TypeError` on a mismatch\"\n",
    "    # Only the first `co_argcount` entries of `co_varnames` are positional params; the rest are\n",
    "    # keyword-only/`*args`/`**kwargs` names and locals, which positional values must not be mapped to\n",
    "    names = f.__code__.co_varnames[:f.__code__.co_argcount]\n",
    "    anno = annotations(f)\n",
    "    ret = anno.pop('return',None)\n",
    "    def _f(*args,**kwargs):\n",
    "        if anno:\n",
    "            # Positional values override same-named kwargs; surplus `*args` values are not checked\n",
    "            kw = {**kwargs, **dict(zip(names, args))}\n",
    "            for k,v in kw.items():\n",
    "                if k in anno and not isinstance(v,anno[k]): raise _typeerr(k, v, anno[k])\n",
    "        res = f(*args,**kwargs)\n",
    "        if ret is not None and not isinstance(res,ret): raise _typeerr(\"return\", res, ret)\n",
    "        return res\n",
    "    return functools.update_wrapper(_f, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`typed` validates argument and return types at **runtime**. This is in contrast to [mypy](https://mypy-lang.org/), which only offers static type checking.\n",
    "\n",
    "For example, a `TypeError` will be raised if we try to pass a float into the first argument (annotated as `int`) of the below function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typed\n",
    "def discount(price:int, pct:float): return (1-pct) * price\n",
    "\n",
    "# `price` is annotated `int`, so passing the float `100.0` is rejected\n",
    "with ExceptionExpected(TypeError): discount(100.0, .1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also allow multiple types by annotating a parameter with a union such as `int|float`, as illustrated below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Without `@typed` this example would not exercise runtime checking at all\n",
    "@typed\n",
    "def discount(price:int|float, pct:float):\n",
    "    return (1-pct) * price\n",
    "\n",
    "assert 90.0 == discount(100.0, .1)\n",
    "with ExceptionExpected(TypeError): discount('100', .1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typed\n",
    "def foo(a:int, b:str='a'): return a\n",
    "\n",
    "test_eq(foo(1, '2'), 1)\n",
    "with ExceptionExpected(TypeError): foo(1, 2)\n",
    "\n",
    "# Return annotations are enforced as well\n",
    "@typed\n",
    "def foo()->str: return 1\n",
    "\n",
    "with ExceptionExpected(TypeError): foo()\n",
    "\n",
    "@typed\n",
    "def foo()->str: return '1'\n",
    "\n",
    "assert foo()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`typed` works with classes, too:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Foo:\n",
    "    @typed\n",
    "    def __init__(self, a:int, b:int, c:str): pass\n",
    "    # A regular method: its first parameter is the instance\n",
    "    @typed\n",
    "    def test(self, d:str): return d\n",
    "\n",
    "with ExceptionExpected(TypeError): Foo(1, 2, 3)\n",
    "with ExceptionExpected(TypeError): Foo(1, 2, 'a string').test(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def exec_new(code):\n",
    "    \"Execute `code` in a new environment and return it\"\n",
    "    pkg = None if __name__=='__main__' else Path().cwd().name\n",
    "    g = {'__name__': __name__, '__package__': pkg}\n",
    "    exec(code, g)\n",
    "    return g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `exec_new` returns the globals dict that the code ran in\n",
    "g = exec_new('a=1')\n",
    "test_eq(g.get('a'), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def exec_import(mod, sym):\n",
    "    \"Import `sym` from `mod` in a new environment (via `exec_new`) and return that environment's globals dict\"\n",
    "    # `mod` and `sym` are interpolated into executed source, so they must come from trusted code\n",
    "    return exec_new(f'from {mod} import {sym}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def str2bool(s):\n",
    "    \"Case-insensitive convert string `s` too a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)\"\n",
    "    if not isinstance(s,str): return bool(s)\n",
    "    if not s: return False\n",
    "    s = s.lower()\n",
    "    if s in ('y', 'yes', 't', 'true', 'on', '1'): return 1\n",
    "    elif s in ('n', 'no', 'f', 'false', 'off', '0'): return 0\n",
    "    else: raise ValueError()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
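    "True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0' (matching is case-insensitive). Raises `ValueError` if `s` is any other non-empty string. Non-string inputs are converted with `bool`, and the empty string is treated as false."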
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for o in \"y YES t True on 1\".split(): assert str2bool(o)\n",
    "for o in \"n no FALSE off 0\".split(): assert not str2bool(o)\n",
    "for o in 0,None,'',False: assert not str2bool(o)\n",
    "for o in 1,True: assert str2bool(o)\n",
    "# Unrecognised strings are rejected rather than guessed\n",
    "with ExceptionExpected(ValueError): str2bool('maybe')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### ipython_shell\n",
       "\n",
       ">      ipython_shell ()\n",
       "\n",
       "Same as `get_ipython` but returns `False` if not in IPython"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### ipython_shell\n",
       "\n",
       ">      ipython_shell ()\n",
       "\n",
       "Same as `get_ipython` but returns `False` if not in IPython"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ipython_shell)  # API docs for `ipython_shell`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### in_ipython\n",
       "\n",
       ">      in_ipython ()\n",
       "\n",
       "Check if code is running in some kind of IPython environment"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### in_ipython\n",
       "\n",
       ">      in_ipython ()\n",
       "\n",
       "Check if code is running in some kind of IPython environment"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(in_ipython)  # API docs for `in_ipython`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### in_colab\n",
       "\n",
       ">      in_colab ()\n",
       "\n",
       "Check if the code is running in Google Colaboratory"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### in_colab\n",
       "\n",
       ">      in_colab ()\n",
       "\n",
       "Check if the code is running in Google Colaboratory"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(in_colab)  # API docs for `in_colab`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### in_jupyter\n",
       "\n",
       ">      in_jupyter ()\n",
       "\n",
       "Check if the code is running in a jupyter notebook"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### in_jupyter\n",
       "\n",
       ">      in_jupyter ()\n",
       "\n",
       "Check if the code is running in a jupyter notebook"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(in_jupyter)  # API docs for `in_jupyter`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### in_notebook\n",
       "\n",
       ">      in_notebook ()\n",
       "\n",
       "Check if the code is running in a jupyter notebook"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### in_notebook\n",
       "\n",
       ">      in_notebook ()\n",
       "\n",
       "Check if the code is running in a jupyter notebook"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(in_notebook)  # NOTE(review): the rendered docstring is identical to `in_jupyter`'s; consider making it specific"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results of these checks are also available as booleans in `fastcore.basics`: `IN_IPYTHON`, `IN_JUPYTER`, `IN_COLAB` and `IN_NOTEBOOK`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(True, True, False, True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The values shown depend on the environment that executed this notebook\n",
    "IN_IPYTHON, IN_JUPYTER, IN_COLAB, IN_NOTEBOOK"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import nbdev\n",
    "nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
